aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-26 11:54:30 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-26 11:54:30 +0800
commit35174f46b973c66a2e6894a12b3018d60e8414ec (patch)
tree5bdae0172159bc02ec3a470722bf959b14dd47ba /tensorflow
parentf0886f7269de900d226455d4831722f6fc94a71b (diff)
parent6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD64
-rw-r--r--tensorflow/api_template.__init__.py15
-rw-r--r--tensorflow/c/c_api.cc1
-rw-r--r--tensorflow/c/c_api_experimental.cc51
-rw-r--r--tensorflow/c/c_api_experimental.h9
-rw-r--r--tensorflow/c/c_api_function.cc1
-rw-r--r--tensorflow/c/eager/BUILD5
-rwxr-xr-xtensorflow/c/eager/c_api.cc18
-rwxr-xr-xtensorflow/c/eager/c_api.h7
-rw-r--r--tensorflow/c/eager/tape.h130
-rw-r--r--tensorflow/c/python_api.cc7
-rw-r--r--tensorflow/c/python_api.h13
-rw-r--r--tensorflow/cc/BUILD28
-rw-r--r--tensorflow/compiler/aot/tests/BUILD16
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py8
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt13
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc48
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl1
-rw-r--r--tensorflow/compiler/jit/BUILD64
-rw-r--r--tensorflow/compiler/jit/build_xla_launch_ops_pass.cc142
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc189
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.h (renamed from tensorflow/compiler/jit/build_xla_launch_ops_pass.h)10
-rw-r--r--tensorflow/compiler/jit/create_xla_launch_op.cc2
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.h29
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc18
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD7
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc276
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h87
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.cc499
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.h168
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc78
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc66
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc21
-rw-r--r--tensorflow/compiler/jit/node_matchers.cc458
-rw-r--r--tensorflow/compiler/jit/node_matchers.h197
-rw-r--r--tensorflow/compiler/jit/node_matchers_test.cc179
-rw-r--r--tensorflow/compiler/jit/ops/BUILD8
-rw-r--r--tensorflow/compiler/jit/ops/xla_ops.cc39
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass_test.cc11
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc7
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc10
-rw-r--r--tensorflow/compiler/jit/xla_device.cc12
-rw-r--r--tensorflow/compiler/jit/xla_device.h12
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h10
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc11
-rw-r--r--tensorflow/compiler/jit/xla_interpreter_device.cc6
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc18
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h17
-rw-r--r--tensorflow/compiler/tests/BUILD21
-rw-r--r--tensorflow/compiler/tests/argminmax_test.py4
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py18
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl169
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py35
-rw-r--r--tensorflow/compiler/tests/dense_layer_test.py25
-rw-r--r--tensorflow/compiler/tests/gather_test.py14
-rw-r--r--tensorflow/compiler/tests/image_ops_test.py12
-rw-r--r--tensorflow/compiler/tests/jit_test.py48
-rw-r--r--tensorflow/compiler/tests/matrix_band_part_test.py190
-rw-r--r--tensorflow/compiler/tests/quantized_ops_test.py48
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py19
-rw-r--r--tensorflow/compiler/tests/reverse_sequence_op_test.py2
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py20
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py7
-rw-r--r--tensorflow/compiler/tests/ternary_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py7
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py43
-rw-r--r--tensorflow/compiler/tests/xla_test.py19
-rw-r--r--tensorflow/compiler/tf2xla/BUILD19
-rw-r--r--tensorflow/compiler/tf2xla/cc/BUILD4
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc12
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc74
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.h13
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc168
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h9
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc25
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc43
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc17
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.cc1
-rw-r--r--tensorflow/compiler/tf2xla/graph_compiler.h13
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc22
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc33
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc509
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc551
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc60
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc76
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/lib/batch_dot.cc10
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc85
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc30
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py8
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.cc14
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.h5
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla.cc8
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc102
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h62
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc11
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc22
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc28
-rw-r--r--tensorflow/compiler/tf2xla/xla_cpu_backend.cc15
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc4
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h5
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc24
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h31
-rw-r--r--tensorflow/compiler/xla/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc44
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h7
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc13
-rw-r--r--tensorflow/compiler/xla/literal.cc20
-rw-r--r--tensorflow/compiler/xla/literal.h25
-rw-r--r--tensorflow/compiler/xla/literal_test.cc3
-rw-r--r--tensorflow/compiler/xla/protobuf_util.cc29
-rw-r--r--tensorflow/compiler/xla/protobuf_util.h4
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc6
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h3
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py19
-rw-r--r--tensorflow/compiler/xla/python/xla_client_test.py24
-rw-r--r--tensorflow/compiler/xla/reference_util.cc47
-rw-r--r--tensorflow/compiler/xla/reference_util.h50
-rw-r--r--tensorflow/compiler/xla/rpc/BUILD13
-rw-r--r--tensorflow/compiler/xla/rpc/grpc_service_main.cc21
-rw-r--r--tensorflow/compiler/xla/service/BUILD63
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc58
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc17
-rw-r--r--tensorflow/compiler/xla/service/batch_dot_simplification.h2
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_conversion_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization.h4
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc6
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc26
-rw-r--r--tensorflow/compiler/xla/service/call_inliner.h2
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/compile_only_service.cc2
-rw-r--r--tensorflow/compiler/xla/service/conditional_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/convolution_feature_group_converter.h2
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD17
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc122
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h44
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc171
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h6
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc236
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h88
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD14
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc54
-rw-r--r--tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc18
-rw-r--r--tensorflow/compiler/xla/service/defuser.h2
-rw-r--r--tensorflow/compiler/xla/service/despecializer.cc2
-rw-r--r--tensorflow/compiler/xla/service/despecializer.h2
-rw-r--r--tensorflow/compiler/xla/service/dot_decomposer.h2
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc169
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph.h2
-rw-r--r--tensorflow/compiler/xla/service/gather_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD33
-rw-r--r--tensorflow/compiler/xla/service/gpu/backend_configs.proto14
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc43
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h25
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc159
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc58
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc194
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h55
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc278
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h37
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc94
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc118
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h56
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc31
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc4
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc30
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc35
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/BUILD60
-rw-r--r--tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc283
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc205
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h62
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc130
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto7
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc31
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h17
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_constant_folding_test.cc33
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.cc39
-rw-r--r--tensorflow/compiler/xla/service/hlo_creation_utils.h30
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_isolator.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_remover.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_element_type_converter.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc205
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc141
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h176
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc11
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h18
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc84
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc53
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce_test.cc72
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group.cc91
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group.h81
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h9
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_test.cc206
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc95
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc83
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_interface.h35
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc191
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h38
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc259
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h2
-rw-r--r--tensorflow/compiler/xla/service/implicit_broadcast_remover.h2
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/inliner.h2
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc16
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc260
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h36
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h2
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h2
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc4
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher.h762
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc183
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc10
-rw-r--r--tensorflow/compiler/xla/service/reduce_precision_insertion.h2
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.h2
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.cc78
-rw-r--r--tensorflow/compiler/xla/service/scatter_expander.h2
-rw-r--r--tensorflow/compiler/xla/service/service.cc7
-rw-r--r--tensorflow/compiler/xla/service/service.h4
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc5
-rw-r--r--tensorflow/compiler/xla/service/source_map_util.cc66
-rw-r--r--tensorflow/compiler/xla/service/stream_pool.cc10
-rw-r--r--tensorflow/compiler/xla/service/stream_pool_test.cc34
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.h2
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc5
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.h2
-rw-r--r--tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc20
-rw-r--r--tensorflow/compiler/xla/shape_util.h8
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/BUILD49
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl488
-rw-r--r--tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc10
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc78
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.h63
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc158
-rw-r--r--tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc120
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc30
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc4
-rw-r--r--tensorflow/compiler/xla/xla.proto9
-rw-r--r--tensorflow/compiler/xla/xla_data.proto3
-rw-r--r--tensorflow/compiler/xrt/tests/BUILD6
-rw-r--r--tensorflow/contrib/BUILD24
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce_test.py2
-rw-r--r--tensorflow/contrib/autograph/BUILD8
-rw-r--r--tensorflow/contrib/autograph/README.md146
-rw-r--r--tensorflow/contrib/autograph/__init__.py50
-rw-r--r--tensorflow/contrib/autograph/utils/__init__.py29
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py29
-rw-r--r--tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py18
-rw-r--r--tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py29
-rw-r--r--tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py18
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py8
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py5
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc10
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc39
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc10
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py2
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py88
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py42
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses_test.py4
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py2
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state.py40
-rw-r--r--tensorflow/contrib/checkpoint/python/python_state_test.py5
-rw-r--r--tensorflow/contrib/cmake/README.md6
-rw-r--r--tensorflow/contrib/cmake/external/png.cmake3
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt4
-rw-r--r--tensorflow/contrib/cmake/python_protos.txt1
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/coder/python/ops/coder_ops_test.py2
-rw-r--r--tensorflow/contrib/compiler/BUILD11
-rw-r--r--tensorflow/contrib/compiler/jit_test.py2
-rw-r--r--tensorflow/contrib/compiler/xla.py442
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py25
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py2
-rw-r--r--tensorflow/contrib/data/__init__.py15
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc3
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc639
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc84
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py244
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py123
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD52
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py84
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py182
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py57
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py12
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py151
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py17
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py7
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD1
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py12
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py72
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py3
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py3
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py52
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py10
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py45
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py4
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py11
-rw-r--r--tensorflow/contrib/data/python/ops/sliding.py8
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py39
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py4
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py4
-rw-r--r--tensorflow/contrib/deprecated/summaries_test.py10
-rw-r--r--tensorflow/contrib/distribute/README.md2
-rw-r--r--tensorflow/contrib/distribute/__init__.py7
-rw-r--r--tensorflow/contrib/distribute/python/BUILD32
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py10
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py78
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py16
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_utils.py15
-rw-r--r--tensorflow/contrib/distribute/python/estimator_training_test.py248
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py20
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py164
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py30
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py6
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py24
-rw-r--r--tensorflow/contrib/distribute/python/monitor.py1
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py53
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py228
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py90
-rw-r--r--tensorflow/contrib/distribute/python/single_loss_example.py6
-rw-r--r--tensorflow/contrib/distribute/python/step_fn.py7
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py1
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py19
-rw-r--r--tensorflow/contrib/distribute/python/values.py50
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py22
-rw-r--r--tensorflow/contrib/distributions/BUILD55
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py20
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/BUILD51
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py98
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py323
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py150
-rw-r--r--tensorflow/contrib/distributions/python/ops/autoregressive.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/batch_reshape.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/affine.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/permute.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/reshape.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/cauchy.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/deterministic.py10
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution_util.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/gumbel.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/half_normal.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/independent.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/inverse_gamma.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/logistic.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py4
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture_same_family.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/poisson_lognormal.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/quantized_distribution.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/statistical_testing.py42
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_student_t.py3
-rw-r--r--tensorflow/contrib/distributions/python/ops/wishart.py20
-rw-r--r--tensorflow/contrib/eager/README.md7
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py26
-rw-r--r--tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md2
-rw-r--r--tensorflow/contrib/eager/python/parameter_server.py289
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py20
-rw-r--r--tensorflow/contrib/estimator/BUILD60
-rw-r--r--tensorflow/contrib/estimator/__init__.py3
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py30
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py74
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py429
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py611
-rw-r--r--tensorflow/contrib/estimator/python/estimator/early_stopping.py39
-rw-r--r--tensorflow/contrib/estimator/python/estimator/hooks.py1
-rw-r--r--tensorflow/contrib/factorization/BUILD1
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py2
-rw-r--r--tensorflow/contrib/fused_conv/BUILD8
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc3
-rw-r--r--tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py12
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py6
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py6
-rw-r--r--tensorflow/contrib/gan/python/train_test.py4
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc102
-rw-r--r--tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py28
-rw-r--r--tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py2
-rw-r--r--tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py4
-rw-r--r--tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py8
-rw-r--r--tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py2
-rw-r--r--tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py6
-rw-r--r--tensorflow/contrib/kernel_methods/python/losses_test.py38
-rw-r--r--tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py12
-rw-r--r--tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py4
-rw-r--r--tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py6
-rw-r--r--tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py34
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers.py7
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers_test.py36
-rw-r--r--tensorflow/contrib/layers/python/layers/target_column.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py2
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py38
-rw-r--r--tensorflow/contrib/linalg/BUILD44
-rw-r--r--tensorflow/contrib/linalg/__init__.py58
-rw-r--r--tensorflow/contrib/linalg/python/__init__.py19
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py95
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py26
-rw-r--r--tensorflow/contrib/lite/README.md4
-rw-r--r--tensorflow/contrib/lite/build_def.bzl60
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h3
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h7
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.c25
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.h7
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc98
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.h22
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc26
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc28
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc10
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc19
-rw-r--r--tensorflow/contrib/lite/examples/android/app/README.md37
-rw-r--r--tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h3
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD12
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc50
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h15
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.cc16
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental.h25
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc23
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h16
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_test.cc31
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc2
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc2
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc2
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.cc14
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.h7
-rw-r--r--tensorflow/contrib/lite/g3doc/_book.yaml72
-rw-r--r--tensorflow/contrib/lite/g3doc/_index.yaml219
-rw-r--r--tensorflow/contrib/lite/g3doc/_project.yaml8
-rw-r--r--tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml6
-rw-r--r--tensorflow/contrib/lite/g3doc/devguide.md9
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.pngbin0 -> 10942 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.pngbin0 -> 578440 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.pngbin0 -> 7764 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.pngbin0 -> 16308 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.pngbin0 -> 20159 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.pngbin0 -> 35371 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.pngbin0 -> 12002 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.pngbin0 -> 25868 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.pngbin0 -> 7839 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.pngbin0 -> 27152 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.pngbin0 -> 17783 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.pngbin0 -> 17249 bytes
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md7
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md17
-rw-r--r--tensorflow/contrib/lite/g3doc/overview.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md186
-rw-r--r--tensorflow/contrib/lite/g3doc/performance_benchmarks.md174
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md11
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc1
-rw-r--r--tensorflow/contrib/lite/interpreter.h13
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md4
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java26
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java29
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/README.md2
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java2
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java (renamed from tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java)4
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java2
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java93
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java36
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java42
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java9
-rw-r--r--tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java15
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD16
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc141
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc91
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv_test.cc270
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise_test.cc9
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc66
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD14
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h32
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc107
-rw-r--r--tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc147
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h19
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h149
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h203
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h94
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h8
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h1633
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h108
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h135
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h460
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h1475
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/softmax.h202
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.cc56
-rw-r--r--tensorflow/contrib/lite/kernels/internal/test_util.h11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h99
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/op_macros.h46
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc52
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc8
-rw-r--r--tensorflow/contrib/lite/kernels/relu1_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h3
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc21
-rw-r--r--tensorflow/contrib/lite/kernels/zeros_like.cc73
-rw-r--r--tensorflow/contrib/lite/kernels/zeros_like_test.cc78
-rw-r--r--tensorflow/contrib/lite/model.cc17
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.cc17
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h8
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc34
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h31
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc12
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc10
-rw-r--r--tensorflow/contrib/lite/python/convert.py43
-rw-r--r--tensorflow/contrib/lite/python/lite.py11
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py22
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py11
-rw-r--r--tensorflow/contrib/lite/schema/BUILD4
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs19
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h529
-rw-r--r--tensorflow/contrib/lite/testing/BUILD5
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py59
-rw-r--r--tensorflow/contrib/lite/toco/BUILD6
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc28
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md2
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc117
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc81
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc80
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc86
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow_test.cc128
-rw-r--r--tensorflow/contrib/lite/toco/model.h17
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc4
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc13
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc70
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h3
-rw-r--r--tensorflow/contrib/lite/tools/make/Makefile1
-rwxr-xr-xtensorflow/contrib/lite/tools/make/download_dependencies.sh2
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc8
-rw-r--r--tensorflow/contrib/lite/tools/visualize.py2
-rw-r--r--tensorflow/contrib/lite/tutorials/post_training_quant.ipynb703
-rw-r--r--tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py16
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py10
-rw-r--r--tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py10
-rw-r--r--tensorflow/contrib/metrics/python/metrics/classification_test.py28
-rw-r--r--tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py5
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc6
-rw-r--r--tensorflow/contrib/mpi/mpi_rendezvous_mgr.h2
-rw-r--r--tensorflow/contrib/nccl/BUILD21
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_rewrite.cc1
-rw-r--r--tensorflow/contrib/opt/BUILD22
-rw-r--r--tensorflow/contrib/opt/__init__.py5
-rw-r--r--tensorflow/contrib/opt/python/training/agn_optimizer.py262
-rw-r--r--tensorflow/contrib/opt/python/training/agn_optimizer_test.py281
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py6
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py15
-rw-r--r--tensorflow/contrib/predictor/BUILD3
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py19
-rw-r--r--tensorflow/contrib/quantization/README.md2
-rw-r--r--tensorflow/contrib/quantize/BUILD4
-rw-r--r--tensorflow/contrib/quantize/README.md158
-rw-r--r--tensorflow/contrib/quantize/python/common.py4
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py59
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py94
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py15
-rw-r--r--tensorflow/contrib/quantize/python/quantize_parameterized_test.py282
-rw-r--r--tensorflow/contrib/rate/rate_test.py4
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py106
-rw-r--r--tensorflow/contrib/recurrent/python/ops/recurrent.py37
-rw-r--r--tensorflow/contrib/resampler/python/ops/resampler_ops_test.py8
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py32
-rw-r--r--tensorflow/contrib/saved_model/BUILD24
-rw-r--r--tensorflow/contrib/saved_model/__init__.py2
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/__init__.py1
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py2
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py42
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py191
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py39
-rw-r--r--tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py25
-rw-r--r--tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py175
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.cc9
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim.h6
-rw-r--r--tensorflow/contrib/session_bundle/bundle_shim_test.cc14
-rw-r--r--tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py6
-rw-r--r--tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py14
-rw-r--r--tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py8
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py46
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py2
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/util_test.py6
-rw-r--r--tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py2
-rw-r--r--tensorflow/contrib/summary/summary_ops_graph_test.py28
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD2
-rw-r--r--tensorflow/contrib/tensor_forest/client/eval_metrics_test.py8
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/stats_ops.cc3
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/tree_utils.cc12
-rw-r--r--tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py10
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_test.py2
-rw-r--r--tensorflow/contrib/tensorboard/BUILD31
-rw-r--r--tensorflow/contrib/tensorboard/db/loader.cc6
-rw-r--r--tensorflow/contrib/tensorboard/plugins/__init__.py2
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace.py167
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto60
-rw-r--r--tensorflow/contrib/tensorboard/plugins/trace/trace_test.py95
-rw-r--r--tensorflow/contrib/tensorrt/BUILD31
-rw-r--r--tensorflow/contrib/tensorrt/README.md2
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc14
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc19
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc13
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert.py319
-rw-r--r--tensorflow/contrib/tensorrt/python/trt_convert_test.py293
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.cc18
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator.h2
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc21
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py6
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py8
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py6
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py35
-rw-r--r--tensorflow/contrib/text/python/ops/skip_gram_ops_test.py32
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py81
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py157
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py44
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils.py4
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py22
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py22
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py6
-rw-r--r--tensorflow/contrib/tpu/BUILD6
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc20
-rw-r--r--tensorflow/contrib/tpu/ops/replication_ops.cc11
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc22
-rw-r--r--tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc5
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto10
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py7
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h2
-rw-r--r--tensorflow/contrib/tpu/proto/BUILD18
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_config.proto66
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto95
-rw-r--r--tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto75
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py27
-rw-r--r--tensorflow/contrib/tpu/python/tpu/device_assignment.py158
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py681
-rw-r--r--tensorflow/contrib/tpu/python/tpu/topology.py15
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py22
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config.py7
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_config_test.py2
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py30
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py9
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_feed.py22
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_function.py8
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset.py4
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.cc81
-rw-r--r--tensorflow/contrib/verbs/rdma_mgr.h1
-rw-r--r--tensorflow/contrib/verbs/verbs_server_lib.cc5
-rw-r--r--tensorflow/core/BUILD100
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt34
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt22
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt30
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt17
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt38
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt6
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt23
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc21
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h16
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc5
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc9
-rw-r--r--tensorflow/core/common_runtime/device.h10
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc11
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc4
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc3
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc24
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc41
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc15
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h3
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h1
-rw-r--r--tensorflow/core/common_runtime/executor.cc139
-rw-r--r--tensorflow/core/common_runtime/function.cc5
-rw-r--r--tensorflow/core/common_runtime/gpu/cuda_host_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc50
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h45
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc146
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc15
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc30
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h22
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc80
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc293
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h36
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_test.cc19
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id.h32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.cc38
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.cc175
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.h58
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc68
-rw-r--r--tensorflow/core/common_runtime/local_device.cc2
-rw-r--r--tensorflow/core/common_runtime/local_device.h3
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h50
-rw-r--r--tensorflow/core/common_runtime/parallel_concat_optimizer.cc6
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc45
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.h27
-rw-r--r--tensorflow/core/common_runtime/process_state.cc71
-rw-r--r--tensorflow/core/common_runtime/process_state.h15
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h16
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc1
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/common_runtime/single_threaded_cpu_device.h1
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc182
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h137
-rw-r--r--tensorflow/core/common_runtime/tracing_device.h60
-rw-r--r--tensorflow/core/common_runtime/visitable_allocator.h79
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc13
-rw-r--r--tensorflow/core/example/example.proto8
-rw-r--r--tensorflow/core/example/feature_util.h5
-rw-r--r--tensorflow/core/framework/allocator.cc29
-rw-r--r--tensorflow/core/framework/allocator.h39
-rw-r--r--tensorflow/core/framework/allocator_registry.h1
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc1
-rw-r--r--tensorflow/core/framework/cancellation.cc10
-rw-r--r--tensorflow/core/framework/cancellation.h9
-rw-r--r--tensorflow/core/framework/cancellation_test.cc52
-rw-r--r--tensorflow/core/framework/dataset.cc1
-rw-r--r--tensorflow/core/framework/dataset.h96
-rw-r--r--tensorflow/core/framework/device_base.h13
-rw-r--r--tensorflow/core/framework/function.cc24
-rw-r--r--tensorflow/core/framework/function.h4
-rw-r--r--tensorflow/core/framework/function_testlib.cc34
-rw-r--r--tensorflow/core/framework/function_testlib.h3
-rw-r--r--tensorflow/core/framework/model.cc419
-rw-r--r--tensorflow/core/framework/model.h404
-rw-r--r--tensorflow/core/framework/node_def_util.cc20
-rw-r--r--tensorflow/core/framework/node_def_util.h8
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
-rw-r--r--tensorflow/core/framework/op_kernel.cc20
-rw-r--r--tensorflow/core/framework/op_kernel.h31
-rw-r--r--tensorflow/core/framework/op_segment.cc8
-rw-r--r--tensorflow/core/framework/op_segment.h4
-rw-r--r--tensorflow/core/framework/resource_mgr.cc2
-rw-r--r--tensorflow/core/framework/resource_mgr.h6
-rw-r--r--tensorflow/core/framework/tensor.cc134
-rw-r--r--tensorflow/core/framework/tensor.h26
-rw-r--r--tensorflow/core/framework/tensor_test.cc94
-rw-r--r--tensorflow/core/framework/tensor_util.h1
-rw-r--r--tensorflow/core/framework/types.h3
-rw-r--r--tensorflow/core/framework/variant.cc25
-rw-r--r--tensorflow/core/framework/variant.h60
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h32
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc6
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc85
-rw-r--r--tensorflow/core/framework/variant_op_registry.h216
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
-rw-r--r--tensorflow/core/framework/variant_tensor_data.cc22
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h10
-rw-r--r--tensorflow/core/framework/variant_test.cc15
-rw-r--r--tensorflow/core/graph/graph_constructor.cc8
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc9
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc1
-rw-r--r--tensorflow/core/graph/testlib.h2
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc1
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc6
-rw-r--r--tensorflow/core/grappler/clusters/utils.cc13
-rw-r--r--tensorflow/core/grappler/clusters/utils.h2
-rw-r--r--tensorflow/core/grappler/clusters/utils_test.cc22
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc199
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc385
-rw-r--r--tensorflow/core/grappler/costs/utils.cc16
-rw-r--r--tensorflow/core/grappler/costs/utils.h2
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc8
-rw-r--r--tensorflow/core/grappler/graph_view.cc33
-rw-r--r--tensorflow/core/grappler/graph_view.h10
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc83
-rw-r--r--tensorflow/core/grappler/inputs/utils.cc7
-rw-r--r--tensorflow/core/grappler/inputs/utils.h4
-rw-r--r--tensorflow/core/grappler/op_types.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD75
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc150
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc65
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc30
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc23
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD122
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc196
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.h108
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils_test.cc164
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc94
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h27
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc82
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc106
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc94
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc43
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD69
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc54
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc61
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h49
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h75
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc50
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc292
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h90
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc600
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc48
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc35
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc264
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc194
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc3
-rw-r--r--tensorflow/core/grappler/utils.cc30
-rw-r--r--tensorflow/core/grappler/utils.h6
-rw-r--r--tensorflow/core/grappler/utils/BUILD29
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc9
-rw-r--r--tensorflow/core/grappler/utils/scc.h7
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.cc)2
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.h (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.h)6
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes_test.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc)2
-rw-r--r--tensorflow/core/grappler/utils_test.cc36
-rw-r--r--tensorflow/core/kernels/BUILD104
-rw-r--r--tensorflow/core/kernels/bias_op.cc13
-rw-r--r--tensorflow/core/kernels/bincount_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto13
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc38
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD4
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc26
-rw-r--r--tensorflow/core/kernels/conv_2d.h45
-rw-r--r--tensorflow/core/kernels/conv_3d.h43
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc6
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc11
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.h10
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc1330
-rw-r--r--tensorflow/core/kernels/conv_ops.cc19
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc20
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h6
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc81
-rw-r--r--tensorflow/core/kernels/data/BUILD29
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc156
-rw-r--r--tensorflow/core/kernels/data/captured_function.h25
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc37
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h10
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc61
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc13
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc44
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc51
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc12
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc41
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc82
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc14
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc234
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc183
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc633
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc22
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc207
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc55
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc61
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc7
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc13
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc38
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc19
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc6
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc215
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc7
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc3
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc8
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h311
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h41
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc31
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h1438
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h342
-rw-r--r--tensorflow/core/kernels/eigen_volume_patch.h1
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.cc197
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.h58
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc38
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc45
-rw-r--r--tensorflow/core/kernels/gather_functor.h1
-rw-r--r--tensorflow/core/kernels/histogram_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc3
-rw-r--r--tensorflow/core/kernels/logging_ops.cc54
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/mirror_pad_op.h1
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops_test.cc407
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc4
-rw-r--r--tensorflow/core/kernels/queue_base.h4
-rw-r--r--tensorflow/core/kernels/queue_ops.cc2
-rw-r--r--tensorflow/core/kernels/random_op.cc10
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h10
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc2
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc2
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc5
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc1
-rw-r--r--tensorflow/core/kernels/searchsorted_op.cc249
-rw-r--r--tensorflow/core/kernels/searchsorted_op.h52
-rw-r--r--tensorflow/core/kernels/searchsorted_op_gpu.cu.cc126
-rw-r--r--tensorflow/core/kernels/shape_op_test.cc10
-rw-r--r--tensorflow/core/kernels/split_lib_gpu.cu.cc1
-rw-r--r--tensorflow/core/kernels/split_op.cc7
-rw-r--r--tensorflow/core/kernels/stack_ops.cc26
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc1
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/kernels/string_length_op.cc23
-rw-r--r--tensorflow/core/kernels/string_util.cc63
-rw-r--r--tensorflow/core/kernels/string_util.h45
-rw-r--r--tensorflow/core/kernels/substr_op.cc50
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc105
-rw-r--r--tensorflow/core/kernels/tensor_array.cc3
-rw-r--r--tensorflow/core/kernels/tensor_array.h3
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc5
-rw-r--r--tensorflow/core/kernels/topk_op_gpu.cu.cc6
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc1
-rw-r--r--tensorflow/core/kernels/unravel_index_op.cc10
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h8
-rw-r--r--tensorflow/core/lib/core/status.h1
-rw-r--r--tensorflow/core/lib/core/stringpiece.h6
-rw-r--r--tensorflow/core/lib/core/threadpool.cc49
-rw-r--r--tensorflow/core/lib/core/threadpool.h14
-rw-r--r--tensorflow/core/lib/core/threadpool_test.cc61
-rw-r--r--tensorflow/core/lib/io/block_builder.h1
-rw-r--r--tensorflow/core/lib/io/path.h1
-rw-r--r--tensorflow/core/lib/io/record_reader.cc53
-rw-r--r--tensorflow/core/lib/io/record_reader.h25
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc2
-rw-r--r--tensorflow/core/lib/io/table_test.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.h2
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc6
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.h1
-rw-r--r--tensorflow/core/lib/monitoring/metric_def.h1
-rw-r--r--tensorflow/core/lib/png/png_io.h1
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc5
-rw-r--r--tensorflow/core/ops/array_ops.cc138
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc127
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt628
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc9
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc11
-rw-r--r--tensorflow/core/ops/dataset_ops.cc54
-rw-r--r--tensorflow/core/ops/logging_ops.cc19
-rw-r--r--tensorflow/core/ops/nn_ops.cc10
-rw-r--r--tensorflow/core/ops/ops.pbtxt565
-rw-r--r--tensorflow/core/ops/parsing_ops.cc7
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc7
-rw-r--r--tensorflow/core/ops/string_ops.cc28
-rw-r--r--tensorflow/core/platform/abi.cc4
-rw-r--r--tensorflow/core/platform/abi.h3
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc5
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system.h2
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system_test.cc2
-rw-r--r--tensorflow/core/platform/cord.h26
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl86
-rw-r--r--tensorflow/core/platform/default/cord.h21
-rw-r--r--tensorflow/core/platform/default/device_tracer.cc7
-rw-r--r--tensorflow/core/platform/env_test.cc7
-rw-r--r--tensorflow/core/platform/file_system.h11
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc2
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.cc2
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc2
-rw-r--r--tensorflow/core/platform/tracing.h5
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc2
-rw-r--r--tensorflow/core/protobuf/config.proto2
-rw-r--r--tensorflow/core/protobuf/replay_log.proto47
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto2
-rw-r--r--tensorflow/core/public/version.h4
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h31
-rw-r--r--tensorflow/core/util/mkl_util.h5
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc10
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h4
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h14
-rw-r--r--tensorflow/core/util/tensor_bundle/naming.h1
-rw-r--r--tensorflow/core/util/work_sharder.cc2
-rw-r--r--tensorflow/core/util/work_sharder.h3
-rw-r--r--tensorflow/examples/autograph/integration_tests/BUILD (renamed from tensorflow/contrib/autograph/examples/integration_tests/BUILD)0
-rw-r--r--tensorflow/examples/autograph/integration_tests/errors_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/errors_test.py)30
-rw-r--r--tensorflow/examples/autograph/integration_tests/keras_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/keras_test.py)2
-rw-r--r--tensorflow/examples/autograph/integration_tests/list_literals_test.py (renamed from tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py)2
-rw-r--r--tensorflow/examples/learn/text_classification_character_cnn.py2
-rw-r--r--tensorflow/examples/speech_commands/freeze_test.py6
-rw-r--r--tensorflow/examples/speech_commands/input_data_test.py4
-rw-r--r--tensorflow/examples/speech_commands/label_wav_test.py2
-rw-r--r--tensorflow/examples/speech_commands/models_test.py12
-rw-r--r--tensorflow/examples/tutorials/mnist/BUILD12
-rw-r--r--tensorflow/go/README.md6
-rw-r--r--tensorflow/go/op/wrappers.go2721
-rw-r--r--tensorflow/java/README.md7
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/spark-tensorflow-connector/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow-hadoop/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/python/BUILD43
-rw-r--r--tensorflow/python/autograph/BUILD31
-rw-r--r--tensorflow/python/autograph/CONTRIBUTING.md (renamed from tensorflow/contrib/autograph/CONTRIBUTING.md)9
-rw-r--r--tensorflow/python/autograph/LIMITATIONS.md (renamed from tensorflow/contrib/autograph/LIMITATIONS.md)0
-rw-r--r--tensorflow/python/autograph/README.md143
-rw-r--r--tensorflow/python/autograph/STYLE_GUIDE.md (renamed from tensorflow/contrib/autograph/STYLE_GUIDE.md)0
-rw-r--r--tensorflow/python/autograph/__init__.py70
-rw-r--r--tensorflow/python/autograph/converters/BUILD (renamed from tensorflow/contrib/autograph/converters/BUILD)54
-rw-r--r--tensorflow/python/autograph/converters/__init__.py (renamed from tensorflow/contrib/autograph/converters/__init__.py)0
-rw-r--r--tensorflow/python/autograph/converters/asserts.py (renamed from tensorflow/contrib/autograph/converters/asserts.py)4
-rw-r--r--tensorflow/python/autograph/converters/asserts_test.py (renamed from tensorflow/contrib/autograph/converters/asserts_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/break_statements.py (renamed from tensorflow/contrib/autograph/converters/break_statements.py)8
-rw-r--r--tensorflow/python/autograph/converters/break_statements_test.py (renamed from tensorflow/contrib/autograph/converters/break_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py (renamed from tensorflow/contrib/autograph/converters/builtin_functions.py)17
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py (renamed from tensorflow/contrib/autograph/converters/builtin_functions_test.py)20
-rw-r--r--tensorflow/python/autograph/converters/call_trees.py (renamed from tensorflow/contrib/autograph/converters/call_trees.py)23
-rw-r--r--tensorflow/python/autograph/converters/call_trees_test.py (renamed from tensorflow/contrib/autograph/converters/call_trees_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions.py (renamed from tensorflow/contrib/autograph/converters/conditional_expressions.py)8
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions_test.py (renamed from tensorflow/contrib/autograph/converters/conditional_expressions_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/continue_statements.py (renamed from tensorflow/contrib/autograph/converters/continue_statements.py)8
-rw-r--r--tensorflow/python/autograph/converters/continue_statements_test.py (renamed from tensorflow/contrib/autograph/converters/continue_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/control_flow.py (renamed from tensorflow/contrib/autograph/converters/control_flow.py)12
-rw-r--r--tensorflow/python/autograph/converters/control_flow_test.py (renamed from tensorflow/contrib/autograph/converters/control_flow_test.py)6
-rw-r--r--tensorflow/python/autograph/converters/decorators.py (renamed from tensorflow/contrib/autograph/converters/decorators.py)4
-rw-r--r--tensorflow/python/autograph/converters/decorators_test.py (renamed from tensorflow/contrib/autograph/converters/decorators_test.py)16
-rw-r--r--tensorflow/python/autograph/converters/directives.py (renamed from tensorflow/contrib/autograph/converters/directives.py)6
-rw-r--r--tensorflow/python/autograph/converters/directives_test.py (renamed from tensorflow/contrib/autograph/converters/directives_test.py)12
-rw-r--r--tensorflow/python/autograph/converters/error_handlers.py (renamed from tensorflow/contrib/autograph/converters/error_handlers.py)6
-rw-r--r--tensorflow/python/autograph/converters/error_handlers_test.py (renamed from tensorflow/contrib/autograph/converters/error_handlers_test.py)10
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions.py (renamed from tensorflow/contrib/autograph/converters/list_comprehensions.py)4
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions_test.py (renamed from tensorflow/contrib/autograph/converters/list_comprehensions_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/lists.py (renamed from tensorflow/contrib/autograph/converters/lists.py)12
-rw-r--r--tensorflow/python/autograph/converters/lists_test.py (renamed from tensorflow/contrib/autograph/converters/lists_test.py)12
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py (renamed from tensorflow/contrib/autograph/converters/logical_expressions.py)29
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions_test.py (renamed from tensorflow/contrib/autograph/converters/logical_expressions_test.py)14
-rw-r--r--tensorflow/python/autograph/converters/name_scopes.py (renamed from tensorflow/contrib/autograph/converters/name_scopes.py)4
-rw-r--r--tensorflow/python/autograph/converters/name_scopes_test.py (renamed from tensorflow/contrib/autograph/converters/name_scopes_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py (renamed from tensorflow/contrib/autograph/converters/return_statements.py)10
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py (renamed from tensorflow/contrib/autograph/converters/return_statements_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards.py (renamed from tensorflow/contrib/autograph/converters/side_effect_guards.py)12
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards_test.py (renamed from tensorflow/contrib/autograph/converters/side_effect_guards_test.py)4
-rw-r--r--tensorflow/python/autograph/converters/slices.py (renamed from tensorflow/contrib/autograph/converters/slices.py)6
-rw-r--r--tensorflow/python/autograph/converters/slices_test.py (renamed from tensorflow/contrib/autograph/converters/slices_test.py)12
-rw-r--r--tensorflow/python/autograph/core/BUILD (renamed from tensorflow/contrib/autograph/core/BUILD)14
-rw-r--r--tensorflow/python/autograph/core/config.py (renamed from tensorflow/contrib/autograph/core/config.py)4
-rw-r--r--tensorflow/python/autograph/core/converter.py (renamed from tensorflow/contrib/autograph/core/converter.py)34
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py (renamed from tensorflow/contrib/autograph/core/converter_testing.py)30
-rw-r--r--tensorflow/python/autograph/core/errors.py (renamed from tensorflow/contrib/autograph/core/errors.py)3
-rw-r--r--tensorflow/python/autograph/core/errors_test.py (renamed from tensorflow/contrib/autograph/core/errors_test.py)10
-rw-r--r--tensorflow/python/autograph/core/naming.py (renamed from tensorflow/contrib/autograph/core/naming.py)2
-rw-r--r--tensorflow/python/autograph/core/naming_test.py (renamed from tensorflow/contrib/autograph/core/naming_test.py)2
-rw-r--r--tensorflow/python/autograph/docs/pyfunc_dtypes.md (renamed from tensorflow/contrib/autograph/docs/pyfunc_dtypes.md)0
-rw-r--r--tensorflow/python/autograph/impl/BUILD (renamed from tensorflow/contrib/autograph/impl/BUILD)14
-rw-r--r--tensorflow/python/autograph/impl/api.py (renamed from tensorflow/contrib/autograph/impl/api.py)110
-rw-r--r--tensorflow/python/autograph/impl/api_test.py (renamed from tensorflow/contrib/autograph/impl/api_test.py)60
-rw-r--r--tensorflow/python/autograph/impl/conversion.py (renamed from tensorflow/contrib/autograph/impl/conversion.py)57
-rw-r--r--tensorflow/python/autograph/impl/conversion_test.py (renamed from tensorflow/contrib/autograph/impl/conversion_test.py)10
-rw-r--r--tensorflow/python/autograph/lang/BUILD (renamed from tensorflow/contrib/autograph/lang/BUILD)2
-rw-r--r--tensorflow/python/autograph/lang/directives.py (renamed from tensorflow/contrib/autograph/lang/directives.py)0
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py (renamed from tensorflow/contrib/autograph/lang/special_functions.py)2
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py (renamed from tensorflow/contrib/autograph/lang/special_functions_test.py)6
-rw-r--r--tensorflow/python/autograph/operators/BUILD (renamed from tensorflow/contrib/autograph/operators/BUILD)3
-rw-r--r--tensorflow/python/autograph/operators/__init__.py (renamed from tensorflow/contrib/autograph/operators/__init__.py)32
-rw-r--r--tensorflow/python/autograph/operators/control_flow.py (renamed from tensorflow/contrib/autograph/operators/control_flow.py)2
-rw-r--r--tensorflow/python/autograph/operators/control_flow_test.py (renamed from tensorflow/contrib/autograph/operators/control_flow_test.py)2
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py (renamed from tensorflow/contrib/autograph/operators/data_structures.py)0
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py (renamed from tensorflow/contrib/autograph/operators/data_structures_test.py)2
-rw-r--r--tensorflow/python/autograph/operators/dispatch_context.py (renamed from tensorflow/contrib/autograph/operators/dispatch_context.py)0
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py (renamed from tensorflow/contrib/autograph/operators/py_builtins.py)11
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py (renamed from tensorflow/contrib/autograph/operators/py_builtins_test.py)27
-rw-r--r--tensorflow/python/autograph/operators/slices.py (renamed from tensorflow/contrib/autograph/operators/slices.py)0
-rw-r--r--tensorflow/python/autograph/operators/slices_test.py (renamed from tensorflow/contrib/autograph/operators/slices_test.py)6
-rw-r--r--tensorflow/python/autograph/pyct/BUILD (renamed from tensorflow/contrib/autograph/pyct/BUILD)0
-rw-r--r--tensorflow/python/autograph/pyct/__init__.py (renamed from tensorflow/contrib/autograph/pyct/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/anno.py (renamed from tensorflow/contrib/autograph/pyct/anno.py)0
-rw-r--r--tensorflow/python/autograph/pyct/anno_test.py (renamed from tensorflow/contrib/autograph/pyct/anno_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/ast_util.py (renamed from tensorflow/contrib/autograph/pyct/ast_util.py)4
-rw-r--r--tensorflow/python/autograph/pyct/ast_util_test.py (renamed from tensorflow/contrib/autograph/pyct/ast_util_test.py)10
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py (renamed from tensorflow/contrib/autograph/pyct/cfg.py)15
-rw-r--r--tensorflow/python/autograph/pyct/cfg_test.py (renamed from tensorflow/contrib/autograph/pyct/cfg_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/BUILD (renamed from tensorflow/contrib/autograph/pyct/common_transformers/BUILD)2
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/__init__.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/anf.py)4
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf_test.py (renamed from tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py)8
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py (renamed from tensorflow/contrib/autograph/pyct/compiler.py)15
-rw-r--r--tensorflow/python/autograph/pyct/compiler_test.py (renamed from tensorflow/contrib/autograph/pyct/compiler_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py (renamed from tensorflow/contrib/autograph/pyct/inspect_utils.py)0
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py (renamed from tensorflow/contrib/autograph/pyct/inspect_utils_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py (renamed from tensorflow/contrib/autograph/pyct/origin_info.py)8
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py (renamed from tensorflow/contrib/autograph/pyct/origin_info_test.py)67
-rw-r--r--tensorflow/python/autograph/pyct/parser.py (renamed from tensorflow/contrib/autograph/pyct/parser.py)15
-rw-r--r--tensorflow/python/autograph/pyct/parser_test.py (renamed from tensorflow/contrib/autograph/pyct/parser_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer.py (renamed from tensorflow/contrib/autograph/pyct/pretty_printer.py)0
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer_test.py (renamed from tensorflow/contrib/autograph/pyct/pretty_printer_test.py)2
-rw-r--r--tensorflow/python/autograph/pyct/qual_names.py (renamed from tensorflow/contrib/autograph/pyct/qual_names.py)4
-rw-r--r--tensorflow/python/autograph/pyct/qual_names_test.py (renamed from tensorflow/contrib/autograph/pyct/qual_names_test.py)10
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/BUILD (renamed from tensorflow/contrib/autograph/pyct/static_analysis/BUILD)16
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/__init__.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/__init__.py)0
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/activity.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/annos.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/annos.py)0
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/live_values.py)18
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/liveness.py)8
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py)8
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py)14
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/type_info.py)6
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info_test.py (renamed from tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/templates.py (renamed from tensorflow/contrib/autograph/pyct/templates.py)19
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py (renamed from tensorflow/contrib/autograph/pyct/templates_test.py)18
-rw-r--r--tensorflow/python/autograph/pyct/testing/BUILD (renamed from tensorflow/contrib/autograph/pyct/testing/BUILD)6
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen.py (renamed from tensorflow/contrib/autograph/pyct/testing/codegen.py)2
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen_test.py (renamed from tensorflow/contrib/autograph/pyct/testing/codegen_test.py)4
-rw-r--r--tensorflow/python/autograph/pyct/transformer.py (renamed from tensorflow/contrib/autograph/pyct/transformer.py)6
-rw-r--r--tensorflow/python/autograph/pyct/transformer_test.py (renamed from tensorflow/contrib/autograph/pyct/transformer_test.py)6
-rw-r--r--tensorflow/python/autograph/utils/BUILD (renamed from tensorflow/contrib/autograph/utils/BUILD)2
-rw-r--r--tensorflow/python/autograph/utils/__init__.py (renamed from tensorflow/contrib/tensorboard/plugins/trace/__init__.py)13
-rw-r--r--tensorflow/python/autograph/utils/context_managers.py (renamed from tensorflow/contrib/autograph/utils/context_managers.py)0
-rw-r--r--tensorflow/python/autograph/utils/context_managers_test.py (renamed from tensorflow/contrib/autograph/utils/context_managers_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/misc.py (renamed from tensorflow/contrib/autograph/utils/misc.py)0
-rw-r--r--tensorflow/python/autograph/utils/misc_test.py (renamed from tensorflow/contrib/autograph/utils/misc_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch.py (renamed from tensorflow/contrib/autograph/utils/multiple_dispatch.py)12
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch_test.py (renamed from tensorflow/contrib/autograph/utils/multiple_dispatch_test.py)31
-rw-r--r--tensorflow/python/autograph/utils/py_func.py (renamed from tensorflow/contrib/autograph/utils/py_func.py)0
-rw-r--r--tensorflow/python/autograph/utils/py_func_test.py (renamed from tensorflow/contrib/autograph/utils/py_func_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/tensor_list.py (renamed from tensorflow/contrib/autograph/utils/tensor_list.py)0
-rw-r--r--tensorflow/python/autograph/utils/tensor_list_test.py (renamed from tensorflow/contrib/autograph/utils/tensor_list_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/tensors.py (renamed from tensorflow/contrib/autograph/utils/tensors.py)0
-rw-r--r--tensorflow/python/autograph/utils/tensors_test.py (renamed from tensorflow/contrib/autograph/utils/tensors_test.py)2
-rw-r--r--tensorflow/python/autograph/utils/testing.py (renamed from tensorflow/contrib/autograph/utils/testing.py)0
-rw-r--r--tensorflow/python/autograph/utils/type_check.py (renamed from tensorflow/contrib/autograph/utils/type_check.py)0
-rw-r--r--tensorflow/python/autograph/utils/type_check_test.py (renamed from tensorflow/contrib/autograph/utils/type_check_test.py)2
-rw-r--r--tensorflow/python/client/session.py46
-rw-r--r--tensorflow/python/client/session_ref.cc525
-rw-r--r--tensorflow/python/client/session_ref.h (renamed from tensorflow/core/common_runtime/session_ref.h)15
-rw-r--r--tensorflow/python/client/session_test.py82
-rw-r--r--tensorflow/python/client/tf_session.i4
-rw-r--r--tensorflow/python/client/tf_session_helper.cc2
-rw-r--r--tensorflow/python/client/timeline.py3
-rw-r--r--tensorflow/python/client/timeline_test.py4
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD51
-rw-r--r--tensorflow/python/data/kernel_tests/inputs_test.py148
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py6
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py24
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py190
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py176
-rw-r--r--tensorflow/python/data/kernel_tests/window_dataset_op_test.py295
-rw-r--r--tensorflow/python/data/ops/BUILD19
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py201
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py13
-rw-r--r--tensorflow/python/data/ops/multi_device_iterator_ops.py231
-rw-r--r--tensorflow/python/data/ops/optional_ops.py150
-rw-r--r--tensorflow/python/data/ops/readers.py12
-rw-r--r--tensorflow/python/data/util/nest.py34
-rw-r--r--tensorflow/python/data/util/structure.py131
-rw-r--r--tensorflow/python/data/util/structure_test.py36
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py3
-rw-r--r--tensorflow/python/debug/lib/debug_graph_reconstruction_test.py3
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py2
-rw-r--r--tensorflow/python/distribute/estimator_training.py21
-rw-r--r--tensorflow/python/eager/BUILD36
-rw-r--r--tensorflow/python/eager/backprop.py43
-rw-r--r--tensorflow/python/eager/backprop_test.py12
-rw-r--r--tensorflow/python/eager/def_function.py235
-rw-r--r--tensorflow/python/eager/def_function_test.py87
-rw-r--r--tensorflow/python/eager/function.py393
-rw-r--r--tensorflow/python/eager/function_test.py409
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tensor.cc41
-rw-r--r--tensorflow/python/eager/pywrap_tensor.h5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc473
-rw-r--r--tensorflow/python/eager/pywrap_tfe_test.py25
-rw-r--r--tensorflow/python/estimator/BUILD31
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py383
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py704
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_utils.py80
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_utils_test.py187
-rw-r--r--tensorflow/python/estimator/estimator.py26
-rw-r--r--tensorflow/python/estimator/export/export_test.py2
-rw-r--r--tensorflow/python/estimator/keras_test.py166
-rw-r--r--tensorflow/python/estimator/model_fn.py6
-rw-r--r--tensorflow/python/feature_column/BUILD2
-rw-r--r--tensorflow/python/feature_column/feature_column.py25
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py34
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py8
-rw-r--r--tensorflow/python/framework/function.py26
-rw-r--r--tensorflow/python/framework/function_test.py27
-rw-r--r--tensorflow/python/framework/load_library.py65
-rw-r--r--tensorflow/python/framework/ops.py23
-rw-r--r--tensorflow/python/framework/ops_test.py12
-rw-r--r--tensorflow/python/framework/test_util.py166
-rw-r--r--tensorflow/python/framework/test_util_test.py8
-rwxr-xr-xtensorflow/python/keras/BUILD12
-rw-r--r--tensorflow/python/keras/applications/__init__.py3
-rw-r--r--tensorflow/python/keras/backend.py79
-rw-r--r--tensorflow/python/keras/backend_test.py5
-rw-r--r--tensorflow/python/keras/callbacks.py16
-rw-r--r--tensorflow/python/keras/callbacks_test.py34
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py16
-rw-r--r--tensorflow/python/keras/engine/saving_test.py7
-rw-r--r--tensorflow/python/keras/engine/topology_test.py2
-rw-r--r--tensorflow/python/keras/engine/training.py422
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py47
-rw-r--r--tensorflow/python/keras/engine/training_eager.py2
-rw-r--r--tensorflow/python/keras/engine/training_test.py70
-rw-r--r--tensorflow/python/keras/engine/training_utils.py76
-rw-r--r--tensorflow/python/keras/layers/advanced_activations.py16
-rw-r--r--tensorflow/python/keras/layers/advanced_activations_test.py8
-rw-r--r--tensorflow/python/keras/layers/convolutional.py71
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py4
-rw-r--r--tensorflow/python/keras/layers/embeddings.py29
-rw-r--r--tensorflow/python/keras/layers/embeddings_test.py13
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py4
-rw-r--r--tensorflow/python/keras/metrics.py43
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py14
-rw-r--r--tensorflow/python/keras/optimizers_test.py37
-rw-r--r--tensorflow/python/keras/testing_utils.py6
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py6
-rw-r--r--tensorflow/python/keras/utils/data_utils.py8
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py10
-rw-r--r--tensorflow/python/keras/wrappers/scikit_learn_test.py12
-rw-r--r--tensorflow/python/kernel_tests/BUILD108
-rw-r--r--tensorflow/python/kernel_tests/accumulate_n_test.py12
-rw-r--r--tensorflow/python/kernel_tests/ackermann_test.py2
-rw-r--r--tensorflow/python/kernel_tests/argmax_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py254
-rw-r--r--tensorflow/python/kernel_tests/as_string_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/atrous_convolution_test.py2
-rw-r--r--tensorflow/python/kernel_tests/attention_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/barrier_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/base64_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/basic_gpu_test.py6
-rw-r--r--tensorflow/python/kernel_tests/batch_gather_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/batchtospace_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/bcast_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/betainc_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/BUILD13
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py322
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py140
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py20
-rw-r--r--tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py20
-rw-r--r--tensorflow/python/kernel_tests/broadcast_to_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/candidate_sampler_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/cast_op_test.py10
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/checkpoint_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py28
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py75
-rw-r--r--tensorflow/python/kernel_tests/conditional_accumulator_test.py42
-rw-r--r--tensorflow/python/kernel_tests/confusion_matrix_test.py28
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py52
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py332
-rw-r--r--tensorflow/python/kernel_tests/conv1d_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv2d_transpose_test.py8
-rw-r--r--tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/conv3d_transpose_test.py10
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_3d_test.py4
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/cross_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_binary_test.py878
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py1198
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_unary_test.py541
-rw-r--r--tensorflow/python/kernel_tests/decode_bmp_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/decode_compressed_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/decode_csv_op_test.py55
-rw-r--r--tensorflow/python/kernel_tests/decode_image_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/decode_png_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/decode_raw_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py8
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/division_future_test.py2
-rw-r--r--tensorflow/python/kernel_tests/division_past_test.py2
-rw-r--r--tensorflow/python/kernel_tests/duplicate_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/dynamic_stitch_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/embedding_ops_test.py60
-rw-r--r--tensorflow/python/kernel_tests/extract_image_patches_grad_test.py2
-rw-r--r--tensorflow/python/kernel_tests/extract_volume_patches_op_test.py131
-rw-r--r--tensorflow/python/kernel_tests/fft_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/fifo_queue_test.py128
-rw-r--r--tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/fractional_max_pool_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py10
-rw-r--r--tensorflow/python/kernel_tests/gather_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/gradient_correctness_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_n_op_py_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py10
-rw-r--r--tensorflow/python/kernel_tests/in_topk_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/init_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/inplace_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/io_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD16
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py (renamed from tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py)54
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py73
-rw-r--r--tensorflow/python/kernel_tests/linalg_grad_test.py4
-rw-r--r--tensorflow/python/kernel_tests/linalg_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/listdiff_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_logging_level_test.py70
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py280
-rw-r--r--tensorflow/python/kernel_tests/lookup_ops_test.py222
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py216
-rw-r--r--tensorflow/python/kernel_tests/manip_ops_test.py16
-rw-r--r--tensorflow/python/kernel_tests/matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/metrics_test.py258
-rw-r--r--tensorflow/python/kernel_tests/numerics_test.py8
-rw-r--r--tensorflow/python/kernel_tests/pad_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/padding_fifo_queue_test.py124
-rw-r--r--tensorflow/python/kernel_tests/parse_single_example_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py18
-rw-r--r--tensorflow/python/kernel_tests/partitioned_variables_test.py40
-rw-r--r--tensorflow/python/kernel_tests/priority_queue_test.py20
-rw-r--r--tensorflow/python/kernel_tests/random/random_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/reader_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/record_input_test.py14
-rw-r--r--tensorflow/python/kernel_tests/reduce_join_op_test.py16
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test.py36
-rw-r--r--tensorflow/python/kernel_tests/reduction_ops_test_big.py12
-rw-r--r--tensorflow/python/kernel_tests/regex_full_match_op_test.py12
-rw-r--r--tensorflow/python/kernel_tests/regex_replace_op_test.py39
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py36
-rw-r--r--tensorflow/python/kernel_tests/reshape_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/reverse_sequence_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py8
-rw-r--r--tensorflow/python/kernel_tests/scalar_test.py2
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py14
-rw-r--r--tensorflow/python/kernel_tests/session_ops_test.py32
-rw-r--r--tensorflow/python/kernel_tests/sets_test.py10
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py34
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/softmax_op_test.py25
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/softsign_op_test.py9
-rw-r--r--tensorflow/python/kernel_tests/spacetobatch_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py44
-rw-r--r--tensorflow/python/kernel_tests/sparse_cross_op_test.py34
-rw-r--r--tensorflow/python/kernel_tests/sparse_matmul_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py2
-rw-r--r--tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py16
-rw-r--r--tensorflow/python/kernel_tests/sparsemask_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_format_op_test.py384
-rw-r--r--tensorflow/python/kernel_tests/string_join_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/string_length_op_test.py29
-rw-r--r--tensorflow/python/kernel_tests/string_split_op_test.py30
-rw-r--r--tensorflow/python/kernel_tests/string_strip_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/string_to_number_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py163
-rw-r--r--tensorflow/python/kernel_tests/summary_audio_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/summary_image_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/summary_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/summary_tensor_op_test.py14
-rw-r--r--tensorflow/python/kernel_tests/tensor_array_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/tensordot_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/unique_op_test.py20
-rw-r--r--tensorflow/python/kernel_tests/unstack_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/variable_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py60
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py58
-rw-r--r--tensorflow/python/kernel_tests/weights_broadcast_test.py8
-rw-r--r--tensorflow/python/kernel_tests/while_v2_test.py276
-rw-r--r--tensorflow/python/kernel_tests/xent_op_test.py10
-rw-r--r--tensorflow/python/ops/array_ops.py61
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py6
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py6
-rw-r--r--tensorflow/python/ops/control_flow_ops.py11
-rw-r--r--tensorflow/python/ops/ctc_ops.py6
-rw-r--r--tensorflow/python/ops/distributions/beta.py9
-rw-r--r--tensorflow/python/ops/distributions/bijector_impl.py45
-rw-r--r--tensorflow/python/ops/distributions/categorical.py4
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py9
-rw-r--r--tensorflow/python/ops/distributions/distribution.py113
-rw-r--r--tensorflow/python/ops/distributions/gamma.py9
-rw-r--r--tensorflow/python/ops/distributions/kullback_leibler.py4
-rw-r--r--tensorflow/python/ops/distributions/normal.py9
-rw-r--r--tensorflow/python/ops/distributions/student_t.py14
-rw-r--r--tensorflow/python/ops/distributions/util.py12
-rw-r--r--tensorflow/python/ops/embedding_ops.py8
-rw-r--r--tensorflow/python/ops/functional_ops.py40
-rw-r--r--tensorflow/python/ops/gradients_impl.py58
-rw-r--r--tensorflow/python/ops/gradients_test.py39
-rw-r--r--tensorflow/python/ops/image_ops_impl.py54
-rw-r--r--tensorflow/python/ops/image_ops_test.py12
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_addition.py (renamed from tensorflow/contrib/linalg/python/ops/linear_operator_addition.py)0
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_circulant.py18
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py16
-rw-r--r--tensorflow/python/ops/logging_ops.py260
-rw-r--r--tensorflow/python/ops/lookup_ops.py40
-rw-r--r--tensorflow/python/ops/losses/util_test.py6
-rw-r--r--tensorflow/python/ops/math_ops.py25
-rw-r--r--tensorflow/python/ops/nn_ops.py34
-rw-r--r--tensorflow/python/ops/parallel_for/BUILD2
-rw-r--r--tensorflow/python/ops/parallel_for/control_flow_ops_test.py192
-rw-r--r--tensorflow/python/ops/parallel_for/gradients.py2
-rw-r--r--tensorflow/python/ops/parallel_for/gradients_test.py26
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py98
-rw-r--r--tensorflow/python/ops/parsing_ops.py10
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py74
-rw-r--r--tensorflow/python/ops/rnn.py4
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py10
-rw-r--r--tensorflow/python/ops/string_ops.py102
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py1
-rw-r--r--tensorflow/python/ops/while_v2.py580
-rw-r--r--tensorflow/python/platform/gfile.py2
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py42
-rw-r--r--tensorflow/python/profiler/pprof_profiler_test.py2
-rw-r--r--tensorflow/python/pywrap_tensorflow.py2
-rwxr-xr-xtensorflow/python/pywrap_tfe.i5
-rw-r--r--tensorflow/python/saved_model/README.md15
-rw-r--r--tensorflow/python/summary/writer/event_file_writer.py2
-rw-r--r--tensorflow/python/summary/writer/writer_test.py4
-rw-r--r--tensorflow/python/tools/BUILD1
-rw-r--r--tensorflow/python/tools/api/generator/create_python_api.py1
-rw-r--r--tensorflow/python/tools/optimize_for_inference_test.py16
-rw-r--r--tensorflow/python/tools/saved_model_cli.py9
-rw-r--r--tensorflow/python/training/adagrad.py2
-rw-r--r--tensorflow/python/training/distribute.py4
-rw-r--r--tensorflow/python/training/ftrl_test.py4
-rw-r--r--tensorflow/python/training/gradient_descent_test.py10
-rw-r--r--tensorflow/python/training/learning_rate_decay_v2_test.py2
-rw-r--r--tensorflow/python/training/monitored_session.py24
-rw-r--r--tensorflow/python/training/optimizer.py13
-rw-r--r--tensorflow/python/training/quantize_training.i7
-rw-r--r--tensorflow/python/training/saver.py8
-rw-r--r--tensorflow/python/training/saver_test.py32
-rw-r--r--tensorflow/python/training/warm_starting_util_test.py8
-rw-r--r--tensorflow/python/util/memory.py45
-rw-r--r--tensorflow/python/util/nest.py22
-rw-r--r--tensorflow/python/util/nest_test.py6
-rw-r--r--tensorflow/python/util/tf_inspect.py5
-rw-r--r--tensorflow/python/util/util.i27
-rw-r--r--tensorflow/requirements.txt2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc129
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h16
-rw-r--r--tensorflow/stream_executor/device_description.h6
-rw-r--r--tensorflow/stream_executor/dnn.h4
-rw-r--r--tensorflow/stream_executor/lib/array_slice.h8
-rw-r--r--tensorflow/stream_executor/lib/inlined_vector.h4
-rw-r--r--tensorflow/stream_executor/lib/strcat.h6
-rw-r--r--tensorflow/stream_executor/lib/stringpiece.h5
-rw-r--r--tensorflow/stream_executor/plugin_registry.h2
-rw-r--r--tensorflow/stream_executor/stream.cc38
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc24
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h18
-rw-r--r--tensorflow/tensorflow.bzl9
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt18
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt9
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt22
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py2
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le4
-rw-r--r--tensorflow/tools/ci_build/README.md2
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh6
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh2
-rwxr-xr-xtensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh77
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh4
-rwxr-xr-xtensorflow/tools/ci_build/linux/cpu/run_mkl.sh2
-rw-r--r--tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh2
-rw-r--r--tensorflow/tools/compatibility/testdata/test_file_v0_11.py2
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2.py8
-rw-r--r--tensorflow/tools/dist_test/README.md2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel2
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu2
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl2
-rw-r--r--tensorflow/tools/docker/jupyter_notebook_config.py2
-rwxr-xr-xtensorflow/tools/docker/parameterized_docker_build.sh2
-rw-r--r--tensorflow/tools/dockerfiles/README.md6
-rw-r--r--tensorflow/tools/docs/BUILD3
-rw-r--r--tensorflow/tools/docs/generate_lib.py14
-rw-r--r--tensorflow/tools/docs/parser.py2
-rw-r--r--tensorflow/tools/pip_package/BUILD23
-rwxr-xr-xtensorflow/tools/pip_package/build_pip_package.sh3
-rw-r--r--tensorflow/tools/pip_package/setup.py6
-rw-r--r--tensorflow/tools/test/check_futures_test.py3
-rwxr-xr-xtensorflow/workspace.bzl59
1644 files changed, 58966 insertions, 20887 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 386e0096ff..3610eea42a 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -564,6 +564,7 @@ tf_cc_shared_object(
"$(location //tensorflow/c:version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_experimental",
@@ -588,6 +589,7 @@ tf_cc_shared_object(
"$(location //tensorflow:tf_version_script.lds)",
],
}),
+ visibility = ["//visibility:public"],
deps = [
"//tensorflow:tf_exported_symbols.lds",
"//tensorflow:tf_version_script.lds",
@@ -608,6 +610,55 @@ exports_files(
],
)
+genrule(
+ name = "install_headers",
+ srcs = [
+ "//tensorflow/c:headers",
+ "//tensorflow/c/eager:headers",
+ "//tensorflow/cc:headers",
+ "//tensorflow/core:headers",
+ ],
+ outs = ["include"],
+ cmd = """
+ mkdir $@
+ for f in $(SRCS); do
+ d="$${f%/*}"
+ d="$${d#bazel-out*genfiles/}"
+ d="$${d#*external/eigen_archive/}"
+
+ if [[ $${d} == *local_config_* ]]; then
+ continue
+ fi
+
+ if [[ $${d} == external* ]]; then
+ extname="$${d#*external/}"
+ extname="$${extname%%/*}"
+ if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
+ continue
+ fi
+ fi
+
+ mkdir -p "$@/$${d}"
+ cp "$${f}" "$@/$${d}/"
+ done
+ """,
+ tags = ["manual"],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "root_init_gen",
+ srcs = select({
+ "api_version_2": [":tf_python_api_gen_v2"],
+ "//conditions:default": [":tf_python_api_gen_v1"],
+ }),
+ outs = ["__init__.py"],
+ cmd = select({
+ "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
+ "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
+ }),
+)
+
gen_api_init_files(
name = "tf_python_api_gen_v1",
srcs = ["api_template.__init__.py"],
@@ -629,19 +680,6 @@ gen_api_init_files(
root_init_template = "api_template.__init__.py",
)
-genrule(
- name = "root_init_gen",
- srcs = select({
- "api_version_2": [":tf_python_api_gen_v2"],
- "//conditions:default": [":tf_python_api_gen_v1"],
- }),
- outs = ["__init__.py"],
- cmd = select({
- "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)",
- "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)",
- }),
-)
-
py_library(
name = "tensorflow_py",
srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"],
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 53a72b8443..2de740e145 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -14,9 +14,9 @@
# ==============================================================================
"""Bring in all of the public TensorFlow interface into this module."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
+from __future__ import absolute_import as _absolute_import
+from __future__ import division as _division
+from __future__ import print_function as _print_function
import os as _os
@@ -41,6 +41,11 @@ except (ImportError, AttributeError):
from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top
contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib')
del LazyLoader
+# The templated code that replaces the placeholder above sometimes
+# sets the __all__ variable. If it does, we have to be sure to add
+# "contrib".
+if '__all__' in vars():
+ vars()['__all__'].append('contrib')
from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
app.flags = flags # pylint: disable=undefined-variable
@@ -51,10 +56,6 @@ _tf_api_dir = _os.path.dirname(_os.path.dirname(app.__file__)) # pylint: disabl
if _tf_api_dir not in __path__:
__path__.append(_tf_api_dir)
-del absolute_import
-del division
-del print_function
-
# These symbols appear because we import the python package which
# in turn imports from tensorflow.core and tensorflow.python. They
# must come from this module. So python adds these symbols for the
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 173bbea596..79811ceae5 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index c046bd66cd..3bcc62cf2d 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -8704,3 +8705,53 @@ TFE_TensorHandle* TFE_DequeueVariantTensor(TF_Session* session, int tensor_id,
return createTFEDequeue(ctx, TF_VARIANT, queue, status);
}
+
+static void CheckOk(TF_Status* status) {
+ CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
+}
+
+void TFE_TensorHandlePrintDebugString(TFE_TensorHandle* handle) {
+ auto* status = TF_NewStatus();
+ TF_Tensor* t = TFE_TensorHandleResolve(handle, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ tensorflow::Tensor dst;
+ TF_CHECK_OK(TF_TensorToTensor(t, &dst));
+ LOG(INFO) << dst.DebugString();
+
+ TF_DeleteTensor(t);
+ TF_DeleteStatus(status);
+}
+
+TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx) {
+ // Intentionally LOG into INFO below for ease of debugging.
+ VLOG(1) << "TFE_RunConstOp called";
+
+ auto* status = TF_NewStatus();
+ auto* op = TFE_NewOp(ctx, "Const", status);
+ CheckOk(status);
+ TFE_OpSetAttrType(op, "dtype", TF_FLOAT);
+
+ auto* tensor =
+ TF_AllocateTensor(TF_FLOAT, /*shape.data()*/ nullptr, /*shape.size()*/ 0,
+ TF_DataTypeSize(TF_FLOAT) * 1);
+ auto* ptr = reinterpret_cast<char*>(TF_TensorData(tensor));
+ *reinterpret_cast<float*>(ptr) = 17.0;
+
+ TFE_OpSetAttrTensor(op, "value", tensor, status);
+ CheckOk(status);
+ TF_DeleteTensor(tensor);
+ VLOG(1) << "New op created";
+
+ TFE_TensorHandle* retval;
+ int num_retvals = 1;
+ TFE_Execute(op, &retval, &num_retvals, status);
+ CheckOk(status);
+ CHECK_EQ(num_retvals, 1);
+ VLOG(1) << "Op executed";
+
+ TFE_DeleteOp(op);
+ TF_DeleteStatus(status);
+
+ return retval;
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 522c91f67e..a3ca847d96 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -174,6 +174,15 @@ TF_CAPI_EXPORT extern void TFE_EnqueueVariantTensor(TF_Session* session,
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_DequeueVariantTensor(
TF_Session* session, int tensor_id, TF_Status* status);
+// Prints `handle` in a human readable format to standard output for debugging.
+TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
+ TFE_TensorHandle* handle);
+
+// Returns a const scalar tensor.
+// Caller owns both the input and the output tensor handles.
+// TODO: Remove this API with hard-coded tensor computation.
+TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_RunConstOp(TFE_Context* ctx);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index a2c5a42c11..f68f8a3e90 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/strings/base64.h"
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 37be52f57d..3ee31a6a7a 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -68,7 +68,10 @@ tf_cuda_library(
tf_cuda_library(
name = "c_api_internal",
hdrs = ["c_api_internal.h"],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
":c_api",
"//tensorflow/c:c_api",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 349d9bcd7c..0bf3d9542b 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -375,6 +375,17 @@ int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
return result;
}
+int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return -1;
+ }
+ tensorflow::int64 result;
+ status->status = h->handle->NumElements(&result);
+ return result;
+}
+
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
@@ -567,6 +578,13 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
+void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
+ TF_Status* status) {
+ tensorflow::Tensor t;
+ status->status = TF_TensorToTensor(tensor, &t);
+ if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
+}
+
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 337447eec9..6323f8a053 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -163,6 +163,8 @@ TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h,
TF_Status* status);
+TF_CAPI_EXPORT extern int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h,
+ TF_Status* status);
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
@@ -311,6 +313,11 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
const char* attr_name,
const TFE_Op* value);
+TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
+ const char* attr_name,
+ TF_Tensor* tensor,
+ TF_Status* status);
+
TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op,
const char* attr_name,
const void* const* values,
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index ce038a4b57..41b5b8ff36 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace eager {
-// Information about a tensor.
-struct TapeTensor {
- int64 id; // Expected to be unique in the lifetime of this process.
- DataType dtype;
- TensorShape shape;
-};
-
// Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@ struct OpTapeEntry {
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
public:
virtual ~VSpace() {}
@@ -93,10 +86,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
// Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@ class VSpace {
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
public:
// If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@ class GradientTape {
void Watch(int64 tensor_id);
void RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -146,17 +139,18 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ Status ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
bool IsPersistent() const { return persistent_; }
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape<BackwardFunction, TapeTensor> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) {
}
}
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
return false;
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+ int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
tensors.push_back(o);
}
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+ op_type, std::move(tensors), ids, backward_function,
+ backward_function_deleter};
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+ int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
@@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
@@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
namespace {
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
+ OpTape<BackwardFunction, TapeTensor> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
@@ -322,17 +319,17 @@ struct BackpropInitialState {
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
- bool persistent_tape) {
+ OpTape<BackwardFunction, TapeTensor>* op_tape,
+ const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
- BackpropInitialState<BackwardFunction> result;
+ BackpropInitialState<BackwardFunction, TapeTensor> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
@@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
return result;
}
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@ std::vector<int64> InitialStack(
return result;
}
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
+ if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true;
(*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
+ vspace.Ones(op_it->second.output_tensor_info[j]));
break;
}
}
@@ -440,6 +436,18 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
return Status::OK();
}
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+//
+// Some gradient functions can accept None arguments for gradients. The
+// following maps the operation name to the indices at which the corresponding
+// gradient function can accept None values. e.g. FusedBatchNorm outputs 5
+// values and hence receives 5 gradient values during backprop. However the
+// gradient function uses only the first of those values and ignores the rest.
+// The entry, "FusedBatchNorm": [1, 2, 3, 4], indicates that only the gradient
+// corresponding to index 0 is used, and the gradient values at indices 1-4 are
+// ignored (and hence can be None). The backprop algorithm can then leverage
+// this by not constructing zeros to pass for those indices.
gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
static auto* const m = new gtl::FlatMap<string, gtl::FlatSet<int>>({
{"SoftmaxCrossEntropyWithLogits", {1}},
@@ -457,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
@@ -510,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
out_gradients.reserve(trace.output_tensor_info.size());
bool any_gradient_nonzero = false;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
+ const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
@@ -519,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
+ out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
}
} else {
any_gradient_nonzero = true;
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 8486b585c8..247236b760 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,9 +135,8 @@ std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 4bcb5bde62..5cce84020b 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -54,16 +54,17 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns the empty string.
-std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+// `output` if its a resource or variant tensor, or otherwise returns the empty
+// string.
+std::string GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
+// CppShapeInferenceResult::HandleData proto. `output` should be a resource
+// or variant tensor.
// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
// because I couldn't get SWIG to work otherwise.
-void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status);
+void SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+ size_t proto_len, TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index e99d15f85d..9d2208d84d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -10,11 +10,12 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
- "tf_cc_test",
+ "cc_library_with_android_deps",
"tf_cc_binary",
+ "tf_cc_test",
"tf_copts",
"tf_gen_op_wrappers_cc",
- "cc_library_with_android_deps",
+ "transitive_hdrs",
)
cc_library(
@@ -717,3 +718,26 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":coordinator",
+ ":gradient_checker",
+ ":gradients",
+ ":ops",
+ ":queue_runner",
+ ":remote_fused_graph_ops",
+ ":scope",
+ "//tensorflow/cc/profiler",
+ "//tensorflow/cc/saved_model:constants",
+ "//tensorflow/cc/saved_model:loader",
+ "//tensorflow/cc/saved_model:reader",
+ "//tensorflow/cc/saved_model:signature_constants",
+ "//tensorflow/cc/saved_model:tag_constants",
+ "//tensorflow/cc/tools:freeze_saved_model",
+ ],
+)
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 8d94f5495c..10fa33ab5e 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@ py_binary(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@ genrule(
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,11 +240,13 @@ tf_cc_test(
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..de135d7a23 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000000..6b4ac2d7cb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index dd2b151098..f10852c785 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,10 +29,12 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -447,6 +449,30 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
@@ -543,7 +569,13 @@ TEST(TFCompileTest, HloProfiling) {
string hlo_profile_as_string =
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
/*clock_rate_ghz=*/1.0);
- VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+ VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
+
+ // Strip away identifier details from the profile string to avoid this test
+ // being a change detector for xla internals. Identifiers such as '%dot.0.7'
+ // just become '%dot'.
+ RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
+ VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
absl::StrSplit(hlo_profile_as_string, '\n');
@@ -551,16 +583,14 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto add_profile_line = HasSubstr(
- "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
- "%arg1.0.1)");
+ "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto tuple_profile_line = HasSubstr(
- "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
- "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
- auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+ "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+ "f32[2,2]{1,0} %add)");
+ auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe14a..859c84bb91 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@ def tf_library(
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 352f63bc98..4e184729ef 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -50,7 +51,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
],
@@ -62,7 +63,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda([
":jit_compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin",
]),
@@ -76,7 +77,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/jit/legacy_flags:xla_device_flags",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
@@ -94,7 +95,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
@@ -111,7 +112,7 @@ cc_library(
deps = [
":jit_compilation_passes",
":xla_device",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
@@ -265,6 +266,7 @@ cc_library(
srcs = ["jit_compilation_pass_registration.cc"],
deps = [
":compilation_passes",
+ "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration",
"//tensorflow/core:core_cpu_internal",
],
alwayslink = 1,
@@ -279,7 +281,7 @@ cc_library(
deps = [
":common",
":compilation_passes",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -340,7 +342,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -358,7 +360,7 @@ tf_cc_test(
cc_library(
name = "compilation_passes",
srcs = [
- "build_xla_launch_ops_pass.cc",
+ "build_xla_ops_pass.cc",
"deadness_analysis.cc",
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
@@ -368,7 +370,7 @@ cc_library(
"partially_decluster_pass.cc",
],
hdrs = [
- "build_xla_launch_ops_pass.h",
+ "build_xla_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"encapsulate_xla_computations_pass.h",
@@ -458,7 +460,7 @@ tf_cc_test(
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -492,7 +494,7 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/cc:sendrecv_ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:test_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
@@ -523,7 +525,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
- "//tensorflow/compiler/jit/kernels:xla_launch_op",
+ "//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/core:core_cpu",
@@ -598,6 +600,44 @@ tf_cuda_cc_test(
],
)
+cc_library(
+ name = "node_matchers",
+ testonly = True,
+ srcs = ["node_matchers.cc"],
+ hdrs = ["node_matchers.h"],
+ deps = [
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "node_matchers_test",
+ srcs = ["node_matchers_test.cc"],
+ deps = [
+ ":node_matchers",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "xla_ops_py",
+ kernels = ["//tensorflow/compiler/jit/ops:xla_ops"],
+ visibility = [
+ ":friends",
+ ],
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
deleted file mode 100644
index b17ff589e2..0000000000
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/graph_def_util.h"
-#include "tensorflow/core/framework/node_def_builder.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/public/version.h"
-
-namespace tensorflow {
-
-static Status BuildLaunchNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes,
- Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("XlaLaunch");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
-}
-
-static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
- VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
-
- int num_constant_args, num_resource_args;
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args));
- TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args));
-
- if (num_constant_args < 0 || num_resource_args < 0 ||
- num_constant_args + num_resource_args > node->num_inputs()) {
- return errors::InvalidArgument(
- "Invalid number of constant/resource arguments to XLA kernel.");
- }
- const int num_nonconst_args =
- node->num_inputs() - num_constant_args - num_resource_args;
-
- DataTypeVector const_dtypes(node->input_types().begin(),
- node->input_types().begin() + num_constant_args);
- DataTypeVector arg_dtypes(
- node->input_types().begin() + num_constant_args,
- node->input_types().begin() + num_constant_args + num_nonconst_args);
-
- // Build a XlaLaunch operator to execute the function body.
- Node* launch_node;
- TF_RETURN_IF_ERROR(BuildLaunchNode(
- graph->NewName(node->name()), node->type_string(), node->def().attr(),
- node->requested_device(), const_dtypes, num_resource_args, arg_dtypes,
- node->output_types(), graph, &launch_node));
- launch_node->set_assigned_device_name(node->assigned_device_name());
-
- // Copy incoming edges to the launch node.
- for (const Edge* edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(edge->src(), launch_node);
- } else {
- graph->AddEdge(edge->src(), edge->src_output(), launch_node,
- edge->dst_input());
- }
- }
-
- // Copy outgoing edges to the launch node.
- std::vector<const Edge*> out_edges(node->out_edges().begin(),
- node->out_edges().end());
- for (const Edge* edge : out_edges) {
- Node* dst = edge->dst();
- int src_output = edge->src_output();
- int dst_input = edge->dst_input();
- graph->RemoveEdge(edge);
-
- if (edge->IsControlEdge()) {
- graph->AddControlEdge(launch_node, dst);
- } else {
- graph->AddEdge(launch_node, src_output, dst, dst_input);
- }
- }
- graph->RemoveNode(node);
-
- return Status::OK();
-}
-
-Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
- Graph* graph = options.graph->get();
-
- for (Node* n : graph->op_nodes()) {
- // In all cases, only try to compile computational nodes.
- if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
- continue;
- }
-
- // Only compile nodes that are marked for compilation by the
- // compilation-marking pass (via 'attr_name').
- if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
- }
- }
-
- if (VLOG_IS_ON(1)) {
- dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph,
- options.flib_def);
- }
- return Status::OK();
-}
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
new file mode 100644
index 0000000000..13a518d0e8
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+static Status BuildXlaCompileNode(
+ const string& nodename, const string& function_name,
+ const AttrValueMap& function_attr, const string& device_name,
+ const DataTypeVector& constant_dtypes, int num_resources,
+ const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaCompile");
+ def.set_device(device_name);
+ AddNodeAttr("Tconstants", constant_dtypes, &def);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Nresources", num_resources, &def);
+ NameAttrList function;
+ function.set_name(function_name);
+ *function.mutable_attr() = function_attr;
+ AddNodeAttr("function", function, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status BuildXlaRunNode(const string& nodename, const string& device_name,
+ const DataTypeVector& arg_dtypes,
+ const DataTypeVector& result_dtypes, Graph* graph,
+ Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaRun");
+ def.set_device(device_name);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Tresults", result_dtypes, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status GetXlaAttrs(Node* node, int* num_constant_args,
+ int* num_resource_args, DataTypeVector* const_dtypes,
+ DataTypeVector* arg_dtypes) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+
+ if (*num_constant_args < 0 || *num_resource_args < 0 ||
+ *num_constant_args + *num_resource_args > node->num_inputs()) {
+ return errors::InvalidArgument(
+ "Invalid number of constant/resource arguments to XLA kernel.");
+ }
+
+ const int num_nonconst_args =
+ node->num_inputs() - *num_constant_args - *num_resource_args;
+
+ const DataTypeVector& input_types = node->input_types();
+ std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
+ std::back_inserter(*const_dtypes));
+ std::copy(input_types.begin() + *num_constant_args,
+ input_types.begin() + *num_constant_args + num_nonconst_args,
+ std::back_inserter(*arg_dtypes));
+ return Status::OK();
+}
+
+static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
+ int prefix_to_ignore) {
+ for (const Edge* edge : old_node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(edge->src(), new_node);
+ } else if (edge->dst_input() >= prefix_to_ignore) {
+ g->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input() - prefix_to_ignore);
+ }
+ }
+}
+
+static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ Node* dst = edge->dst();
+ int src_output = edge->src_output();
+ int dst_input = edge->dst_input();
+ g->RemoveEdge(edge);
+
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(new_node, dst);
+ } else {
+ g->AddEdge(new_node, src_output, dst, dst_input);
+ }
+ }
+}
+
+static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
+ int num_constant_args, num_resource_args;
+ DataTypeVector const_dtypes;
+ DataTypeVector arg_dtypes;
+
+ TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
+ &const_dtypes, &arg_dtypes));
+
+ Node *compile_node, *run_node;
+
+ TF_RETURN_IF_ERROR(BuildXlaCompileNode(
+ n->name(), n->type_string(), n->def().attr(), n->requested_device(),
+ const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+
+ DataTypeVector arg_dtypes_with_resources = arg_dtypes;
+ for (int i = 0; i < num_resource_args; i++) {
+ arg_dtypes_with_resources.push_back(DT_RESOURCE);
+ }
+
+ TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
+ arg_dtypes_with_resources,
+ n->output_types(), g, &run_node));
+
+ compile_node->set_assigned_device_name(n->assigned_device_name());
+ run_node->set_assigned_device_name(n->assigned_device_name());
+
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
+ /*prefix_to_ignore=*/0);
+ CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
+ /*prefix_to_ignore=*/num_constant_args);
+
+ // The compilation_key output.
+ g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+
+ MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+
+ for (Node* n : graph->op_nodes()) {
+ // In all cases, only try to compile computational nodes.
+ if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+ continue;
+ }
+
+ // Only compile nodes that are marked for compilation by the
+ // compilation-marking pass (via 'attr_name').
+ if (IsXlaCompiledKernel(*n)) {
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ }
+ }
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def);
+ }
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h
index 1dfea93f02..1dd38fa951 100644
--- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.h
@@ -13,19 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
-#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and
+// executes (using XLA) TF function calls marked with "_XlaCompiledKernel".
+class BuildXlaOpsPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc
index 56b034a30b..6f1ff85f24 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
index c8bb4dc114..99e9dfd598 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h
@@ -1,23 +1,22 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
// Rewrites computations generated by the xla.compile() Python code into
// XlaLaunch nodes.
//
// xla.compile() does two main things:
-// a) marks operators that make up a XLA computation with the attribute
+// a) marks operators that make up an XLA computation with the attribute
// _xla_compile_id=XYZ, where XYZ is a unique key.
// b) adds XlaClusterOutput nodes to represent outputs of the computation.
// These nodes are not marked with the _xla_compile_id attribute.
@@ -29,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/env.h"
-namespace tensorflow {
+ namespace tensorflow {
// Encapsulates nodes marked with the _xla_compile_id attribute into
// XlaLaunch operators.
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 315fcb2fa7..085c0e5adb 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -22,12 +22,24 @@ limitations under the License.
namespace tensorflow {
+// PRE_PLACEMENT passes:
+
// EncapsulateXlaComputationsPass rewrites computations generated by the
// xla.compile() Python code into XlaLaunch nodes.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 26,
EncapsulateXlaComputationsPass);
-// The following POST_REWRITE passes support auto-clustering to enable XLA.
+// from
+// third_party/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
+// FunctionalizeControlFlowPass: 27
+//
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (XlaIf/XlaWhile). Following passes must
+// handle those FunctionDef correctly.
+
+// POST_REWRITE_FOR_EXEC passes that support auto-clustering to enable XLA:
+
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
@@ -43,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
// Must run after EncapsulateSubgraphsPass.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
- BuildXlaLaunchOpsPass);
+ BuildXlaOpsPass);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 253a5d2547..0839f1cb3d 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -7,9 +7,9 @@ package(
)
cc_library(
- name = "xla_launch_op",
- srcs = ["xla_launch_op.cc"],
- hdrs = ["xla_launch_op.h"],
+ name = "xla_ops",
+ srcs = ["xla_ops.cc"],
+ hdrs = ["xla_ops.h"],
deps = [
"//tensorflow/compiler/jit:common",
"//tensorflow/compiler/jit:xla_compilation_cache",
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
alwayslink = 1,
)
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
deleted file mode 100644
index b6f2f632f7..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ /dev/null
@@ -1,276 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
-
-#include "tensorflow/compiler/jit/defs.h"
-#include "tensorflow/compiler/jit/xla_launch_util.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
-#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/variable_ops.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function)
- : OpKernel(ctx),
- constants_(constants),
- resources_(resources),
- device_type_(ctx->device_type()),
- function_(function) {
- if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = se::host::kHostPlatformId;
- } else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = ctx->device()
- ->tensorflow_gpu_device_info()
- ->stream->parent()
- ->platform()
- ->id();
- } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) {
- use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams();
- platform_id_ = xla_device_metadata_->platform()->id();
- }
-}
-
-Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache) {
- if (xla_device_metadata_) {
- *cache = new XlaCompilationCache(xla_device_metadata_->client(),
- xla_device_metadata_->jit_device_type());
- return Status::OK();
- }
-
- auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
- if (!platform.ok()) {
- return platform.status();
- }
- xla::LocalClientOptions client_options;
- client_options.set_platform(platform.ValueOrDie());
- client_options.set_intra_op_parallelism_threads(
- ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
- auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
- if (!client.ok()) {
- return client.status();
- }
- const XlaOpRegistry::DeviceRegistration* registration;
- if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(),
- &registration)) {
- return errors::InvalidArgument("No JIT device registered for ",
- device_type_.type());
- }
- *cache = new XlaCompilationCache(
- client.ValueOrDie(), DeviceType(registration->compilation_device_name));
- return Status::OK();
-}
-
-void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
- VLOG(1) << "XlaLocalLaunchOpBase::Compute "
- << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
- // We store information about the JIT-compiled XLA computation
- // in the ResourceMgr.
- ResourceMgr* rm = ctx->resource_manager();
- OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
-
- se::Stream* stream =
- ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
-
- XlaCompilationCache* cache;
- OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>(
- rm->default_container(), "xla_cache", &cache,
- [this, ctx](XlaCompilationCache** cache) {
- return BuildCompilationCache(ctx, cache);
- }));
- // Hold the reference to the JIT during evaluation. (We could probably
- // free it sooner because the ResourceMgr will retain a reference, but
- // this is more obviously correct.)
- core::ScopedUnref cache_ref(cache);
-
- std::map<int, OptionalTensor> variables =
- SnapshotResourceVariables(ctx, resources_);
-
- xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
-
- XlaAllocator local_xla_allocator(client->backend().platform(),
- ctx->device()->GetAllocator({}));
- xla::DeviceMemoryAllocator* xla_allocator;
- // If we are on an XlaDevice, use the underlying XLA platform's allocator
- // directly. We could use the StreamExecutor's allocator which may
- // theoretically be more correct, but XLA returns a nice OOM message in a
- // Status and StreamExecutor does not.
- //
- // Importantly we can't use ctx->device()->GetAllocator() as the allocator
- // (which local_xla_allocator above uses) as on an XlaDevice, this is a
- // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a
- // real allocator to allocate real buffers.
- if (xla_device_metadata_) {
- xla_allocator = client->backend().memory_allocator();
- } else {
- xla_allocator = &local_xla_allocator;
- }
-
- XlaCompiler::Options options;
- options.client = client;
- if (ctx->op_device_context() != nullptr) {
- options.device_ordinal =
- ctx->op_device_context()->stream()->parent()->device_ordinal();
- }
- options.device_type = cache->device_type();
- options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
- options.device_allocator = xla_allocator;
- if (xla_device_metadata_) {
- options.shape_representation_fn =
- xla_device_metadata_->shape_representation_fn();
- }
-
- const XlaCompiler::CompilationResult* kernel;
- xla::LocalExecutable* executable;
-
- std::map<int, Tensor> constant_args;
- for (int i : constants_) {
- constant_args.insert({i, ctx->input(i)});
- }
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- // If we resolve constants we never emit them on the device, meaning that if
- // they are needed by a following computation the host has to transfer
- // them. Not resolving constants is expected to be faster than resolving
- // constants.
- compile_options.resolve_compile_time_constants = true;
- // Optimization: where possible, have the computation return a naked array
- // rather than a one-element tuple.
- compile_options.always_return_tuple = false;
-
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, compile_options));
-
- VLOG(1) << "Executing XLA Computation...";
-
- XlaComputationLaunchContext launch_context(
- client, xla_allocator,
- /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr,
- use_multiple_streams_);
- launch_context.PopulateInputs(ctx, kernel, variables);
-
- // Execute the computation.
- VLOG(2) << "Executing computation.";
- xla::ExecutableRunOptions run_options;
- run_options.set_stream(stream);
- run_options.set_allocator(xla_allocator);
- run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
- run_options.set_rng_seed(GetXLARandomSeed());
- Env* env = Env::Default();
- auto start_time = env->NowMicros();
-
- auto run_result = executable->Run(launch_context.arguments(), run_options);
- OP_REQUIRES(ctx, run_result.ok(), run_result.status());
-
- auto elapsed = env->NowMicros() - start_time;
- VLOG(2) << "Elapsed time: " << elapsed << "us";
-
- OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
- ctx, kernel, run_result.ConsumeValueOrDie()));
- VLOG(1) << "Done";
-}
-
-namespace {
-
-// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
-// in error case, it returns RET instead of void.
-#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
- do { \
- ::tensorflow::Status _s(__VA_ARGS__); \
- if (!TF_PREDICT_TRUE(_s.ok())) { \
- (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
- return RET; \
- } \
- } while (0)
-
-// Helper static functions to construct parameters for
-// XlaLocalLaunchBase constructor from OpKernelConstruction.
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
- std::vector<int> constants(constant_types.size());
- std::iota(constants.begin(), constants.end(), 0);
- return constants;
-}
-
-std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
- DataTypeVector constant_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Tconstants", &constant_types));
-
- DataTypeVector arg_types;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Targs", &arg_types));
-
- int num_resources;
- OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
- ctx->GetAttr("Nresources", &num_resources));
-
- std::vector<int> resources(num_resources);
- std::iota(resources.begin(), resources.end(),
- constant_types.size() + arg_types.size());
- return resources;
-}
-
-NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
- const NameAttrList* func;
- OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
- return *func;
-}
-
-#undef OP_REQUIRES_OK_RETURN
-} // namespace
-
-XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
- : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
- FunctionAttr(ctx)) {}
-
-XlaLocalLaunchOp::~XlaLocalLaunchOp() {
- VLOG(1) << "XlaLocalLaunchOp destroyed";
-}
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
-
-REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
- .Device(DEVICE_GPU)
- .HostMemory("constants")
- .HostMemory("resources"),
- XlaLocalLaunchOp);
-
-} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
deleted file mode 100644
index e0f10e9817..0000000000
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ /dev/null
@@ -1,87 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
-
-#include "tensorflow/compiler/jit/xla_compilation_cache.h"
-#include "tensorflow/compiler/jit/xla_device.h"
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/util/stream_executor_util.h"
-
-namespace tensorflow {
-
-// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
-// The only difference is that it does not require arguments to follow
-// the "constants, then regular args, then resources" order.
-// It takes vectors of constant and resource arguments explicitly.
-// It does not have corresponding OpDef because it is never present
-// in the GraphDef.
-// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
-// this kernel when asked to create a kernel for an XLA-compiled function.
-class XlaLocalLaunchBase : public OpKernel {
- public:
- XlaLocalLaunchBase(OpKernelConstruction* ctx,
- const std::vector<int>& constants,
- const std::vector<int>& resources,
- const NameAttrList& function);
- XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
- XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
- ~XlaLocalLaunchBase() override = default;
-
- void Compute(OpKernelContext* ctx) override;
-
- protected:
- // Builds a XlaCompilationCache class suitable for the current device.
- Status BuildCompilationCache(OpKernelContext* ctx,
- XlaCompilationCache** cache);
-
- // Indexes of compile-time constant inputs
- std::vector<int> constants_;
- // Indexes of resource inputs
- std::vector<int> resources_;
-
- DeviceType device_type_;
- NameAttrList function_;
- se::Platform::Id platform_id_ = nullptr;
- bool use_multiple_streams_ = false;
- const XlaDevice::Metadata* xla_device_metadata_ = nullptr;
-};
-
-// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
-// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
-// responsible for handling interactions with the TensorFlow executor.
-// Once all inputs are present, and their shapes are known, the op can
-// use a 'XlaCompilationCache' to compile and execute code which is specific
-// to the shapes of input Tensors.
-// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
-// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
-// memory.
-class XlaLocalLaunchOp : public XlaLocalLaunchBase {
- public:
- explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
- ~XlaLocalLaunchOp() override;
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
new file mode 100644
index 0000000000..a85006eb03
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -0,0 +1,499 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status PlatformInfoFromContext(OpKernelConstruction* ctx,
+ XlaPlatformInfo* result) {
+ DeviceType device_type = ctx->device_type();
+ se::Platform::Id platform_id = nullptr;
+ const XlaDevice::Metadata* xla_device_metadata = nullptr;
+ std::unique_ptr<XlaAllocator> xla_allocator;
+ xla::DeviceMemoryAllocator* device_allocator = nullptr;
+
+ if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
+ platform_id = se::host::kHostPlatformId;
+ } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
+ platform_id = ctx->device()
+ ->tensorflow_gpu_device_info()
+ ->stream->parent()
+ ->platform()
+ ->id();
+ } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
+ // If we are on an XlaDevice, use the underlying XLA platform's allocator
+ // directly. We could use the StreamExecutor's allocator which may
+ // theoretically be more correct, but XLA returns a nice OOM message in a
+ // Status and StreamExecutor does not.
+ //
+ // Importantly we can't use ctx->device()->GetAllocator() as the allocator
+ // (which xla_allocator above uses) as on an XlaDevice, this is a dummy
+ // allocator that returns XlaTensor objects. The XlaCompiler needs a real
+ // allocator to allocate real buffers.
+
+ platform_id = xla_device_metadata->platform()->id();
+ device_allocator =
+ xla_device_metadata->client()->backend().memory_allocator();
+ }
+
+ if (!device_allocator) {
+ TF_ASSIGN_OR_RETURN(se::Platform* const platform,
+ se::MultiPlatformManager::PlatformWithId(platform_id));
+ xla_allocator = absl::make_unique<XlaAllocator>(
+ platform, ctx->device()->GetAllocator({}));
+ }
+
+ *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
+ std::move(xla_allocator), device_allocator);
+
+ return Status::OK();
+}
+
+// A closure describing how to run a compiled version of a TensorFlow function.
+//
+// It may seem unusual to stick the resource variable snapshots in this class.
+// This is necessary: we need to use the snapshots observed by the compiler as
+// the initial values for the resource variables (and cannot snapshot them again
+// during execution) because otherwise we risk observing a different snapshot
+// with shapes different from what we compiled for.
+class XlaExecutableClosure {
+ public:
+ explicit XlaExecutableClosure(
+ xla::LocalClient* client, xla::LocalExecutable* executable,
+ const XlaCompiler::CompilationResult* compilation_result,
+ std::map<int, OptionalTensor> resource_var_snapshots,
+ int num_constant_args)
+ : client_(client),
+ executable_(executable),
+ compilation_result_(compilation_result),
+ resource_var_snapshots_(std::move(resource_var_snapshots)),
+ num_constant_args_(num_constant_args) {}
+
+ XlaExecutableClosure(XlaExecutableClosure&&) = default;
+ XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default;
+
+ xla::LocalClient* client() const { return client_; }
+ xla::LocalExecutable* executable() const { return executable_; }
+ const XlaCompiler::CompilationResult* compilation_result() const {
+ return compilation_result_;
+ }
+ const std::map<int, OptionalTensor>& resource_var_snapshots() const {
+ return resource_var_snapshots_;
+ }
+ int num_constant_args() const { return num_constant_args_; }
+
+ private:
+ xla::LocalClient* client_;
+ xla::LocalExecutable* executable_;
+ const XlaCompiler::CompilationResult* compilation_result_;
+ std::map<int, OptionalTensor> resource_var_snapshots_;
+ int num_constant_args_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure);
+};
+
+// This maintains a mapping from a globally unique ID to XlaExecutableClosure
+// instances.
+class XlaExecutableClosureStore {
+ public:
+ XlaExecutableClosureStore() : key_counter_(0) {}
+
+ using KeyT = string;
+
+ KeyT Produce(XlaExecutableClosure result) {
+ mutex_lock l(mutex_);
+ KeyT key = absl::StrCat(key_counter_++);
+ bool insert_successful = closures_.emplace(key, std::move(result)).second;
+ DCHECK(insert_successful);
+ (void)insert_successful;
+ return key;
+ }
+
+ XlaExecutableClosure Consume(const KeyT& key) {
+ mutex_lock l(mutex_);
+ auto it = closures_.find(key);
+ DCHECK(it != closures_.end());
+ XlaExecutableClosure value = std::move(it->second);
+ closures_.erase(it);
+ return value;
+ }
+
+ static XlaExecutableClosureStore* Global() {
+ static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore;
+ return instance;
+ }
+
+ private:
+ mutex mutex_;
+ int64 key_counter_ GUARDED_BY(mutex_);
+ gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
+};
+
+} // namespace
+
+XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function)
+ : OpKernel(ctx),
+ constants_(constants),
+ resources_(resources),
+ function_(function) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+static Status BuildCompilationCache(OpKernelContext* ctx,
+ const XlaPlatformInfo& platform_info,
+ XlaCompilationCache** cache) {
+ if (platform_info.xla_device_metadata()) {
+ *cache = new XlaCompilationCache(
+ platform_info.xla_device_metadata()->client(),
+ platform_info.xla_device_metadata()->jit_device_type());
+ return Status::OK();
+ }
+
+ auto platform =
+ se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
+ if (!platform.ok()) {
+ return platform.status();
+ }
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform.ValueOrDie());
+ client_options.set_intra_op_parallelism_threads(
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
+ if (!client.ok()) {
+ return client.status();
+ }
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
+ &registration)) {
+ return errors::InvalidArgument("No JIT device registered for ",
+ platform_info.device_type().type());
+ }
+ *cache = new XlaCompilationCache(
+ client.ValueOrDie(), DeviceType(registration->compilation_device_name));
+ return Status::OK();
+}
+
+static Status CompileToLocalExecutable(
+ OpKernelContext* ctx, const NameAttrList& function,
+ const XlaPlatformInfo& platform_info, absl::Span<const int> resources,
+ absl::Span<const int> constants, xla::LocalClient** client,
+ std::map<int, OptionalTensor>* variables,
+ const XlaCompiler::CompilationResult** kernel,
+ xla::LocalExecutable** executable) {
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ if (!rm) {
+ return errors::Internal("No resource manager.");
+ }
+
+ XlaCompilationCache* cache;
+ TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_cache", &cache,
+ [&](XlaCompilationCache** cache) {
+ return BuildCompilationCache(ctx, platform_info, cache);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref cache_ref(cache);
+
+ *variables = SnapshotResourceVariables(ctx, resources);
+ *client = static_cast<xla::LocalClient*>(cache->client());
+
+ XlaCompiler::Options options;
+ options.client = *client;
+ if (ctx->op_device_context() != nullptr) {
+ options.device_ordinal =
+ ctx->op_device_context()->stream()->parent()->device_ordinal();
+ }
+ options.device_type = cache->device_type();
+ options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ options.graph_def_version = ctx->function_library()->graph_def_version();
+ options.allow_cpu_custom_calls =
+ (platform_info.platform_id() == se::host::kHostPlatformId);
+ options.device_allocator = platform_info.allocator();
+ if (platform_info.xla_device_metadata()) {
+ options.shape_representation_fn =
+ platform_info.xla_device_metadata()->shape_representation_fn();
+ }
+
+ std::map<int, Tensor> constant_args;
+ for (int i : constants) {
+ constant_args.insert({i, ctx->input(i)});
+ }
+ XlaCompiler::CompileOptions compile_options;
+ compile_options.is_entry_computation = true;
+ // If we resolve constants we never emit them on the device, meaning that if
+ // they are needed by a following computation the host has to transfer
+ // them. Not resolving constants is expected to be faster than resolving
+ // constants.
+ compile_options.resolve_compile_time_constants = true;
+ // Optimization: where possible, have the computation return a naked array
+ // rather than a one-element tuple.
+ compile_options.always_return_tuple = false;
+
+ return cache->Compile(options, function, constant_args, *variables, ctx,
+ kernel, executable, compile_options);
+}
+
+void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOpBase::Compute "
+ << Canonicalize(function_.name(), AttrSlice(&function_.attr()));
+
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ XlaComputationLaunchContext launch_context(
+ client, platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ platform_info_.UseMultipleStreams());
+ launch_context.PopulateInputs(ctx, kernel, variables,
+ /*missing_ctx_input_prefix=*/0);
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result = executable->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
+ VLOG(1) << "Done";
+}
+
+namespace {
+
+// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that
+// in error case, it returns RET instead of void.
+#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \
+ do { \
+ ::tensorflow::Status _s(__VA_ARGS__); \
+ if (!TF_PREDICT_TRUE(_s.ok())) { \
+ (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \
+ return RET; \
+ } \
+ } while (0)
+
+// Helper static functions to construct parameters for
+// XlaLocalLaunchBase constructor from OpKernelConstruction.
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+ std::vector<int> constants(constant_types.size());
+ std::iota(constants.begin(), constants.end(), 0);
+ return constants;
+}
+
+std::vector<int> ResourcesVector(OpKernelConstruction* ctx) {
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Tconstants", &constant_types));
+
+ DataTypeVector arg_types;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Targs", &arg_types));
+
+ int num_resources;
+ OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(),
+ ctx->GetAttr("Nresources", &num_resources));
+
+ std::vector<int> resources(num_resources);
+ std::iota(resources.begin(), resources.end(),
+ constant_types.size() + arg_types.size());
+ return resources;
+}
+
+NameAttrList FunctionAttr(OpKernelConstruction* ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func));
+ return *func;
+}
+
+#undef OP_REQUIRES_OK_RETURN
+} // namespace
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx),
+ FunctionAttr(ctx)) {}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+ VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx),
+ constants_(ConstantsVector(ctx)),
+ resources_(ResourcesVector(ctx)),
+ function_(FunctionAttr(ctx)) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaCompileOp::Compute(OpKernelContext* ctx) {
+ xla::LocalClient* client;
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ std::map<int, OptionalTensor> variables;
+
+ OP_REQUIRES_OK(
+ ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_,
+ constants_, &client, &variables, &kernel,
+ &executable));
+
+ // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
+ // if it didn't have to compile the cluster because of a compilation-cache
+ // hit. This is because we at least need new snapshots of the resource
+ // variables.
+ XlaExecutableClosureStore::KeyT key =
+ XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
+ client, executable, kernel, std::move(variables), constants_.size()));
+
+ Allocator* cpu_allocator = [&] {
+ AllocatorAttributes host_alloc_attrs;
+ host_alloc_attrs.set_gpu_compatible(true);
+ host_alloc_attrs.set_on_host(true);
+ return ctx->device()->GetAllocator(host_alloc_attrs);
+ }();
+
+ Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
+ compilation_key.flat<string>()(0) = key;
+
+ Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
+ compilation_successful.flat<bool>()(0) = true;
+
+ ctx->set_output(0, compilation_key);
+ ctx->set_output(1, compilation_successful);
+}
+
+XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_));
+}
+
+void XlaRunOp::Compute(OpKernelContext* ctx) {
+ Tensor key_tensor = ctx->input(ctx->num_inputs() - 1);
+ const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0);
+
+ XlaExecutableClosure closure =
+ XlaExecutableClosureStore::Global()->Consume(key);
+
+ XlaComputationLaunchContext launch_context(
+ closure.client(), platform_info_.allocator(),
+ /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
+ /*use_multiple_streams=*/platform_info_.UseMultipleStreams());
+
+ // We're missing the must-be-constant inputs, tell `PopulateInputs`
+ // about this. We don't actually need these inputs because they've
+ // already been baked into the compiled kernel.
+ launch_context.PopulateInputs(
+ ctx, closure.compilation_result(), closure.resource_var_snapshots(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args());
+
+ se::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(platform_info_.allocator());
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ run_options.set_rng_seed(GetXLARandomSeed());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+
+ auto run_result =
+ closure.executable()->Run(launch_context.arguments(), run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
+
+ OP_REQUIRES_OK(
+ ctx,
+ launch_context.PopulateOutputs(
+ ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/closure.num_constant_args()));
+}
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("XlaLaunch")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaCompile")
+ .Device(DEVICE_GPU)
+ .HostMemory("constants")
+ .HostMemory("resources"),
+ XlaCompileOp);
+
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp);
+REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h
new file mode 100644
index 0000000000..489d26eb30
--- /dev/null
+++ b/tensorflow/compiler/jit/kernels/xla_ops.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run on.
+class XlaPlatformInfo {
+ public:
+ XlaPlatformInfo() : device_type_("") {}
+ explicit XlaPlatformInfo(const DeviceType device_type,
+ se::Platform::Id platform_id,
+ const XlaDevice::Metadata* xla_device_metadata,
+ std::unique_ptr<XlaAllocator> xla_allocator,
+ xla::DeviceMemoryAllocator* device_allocator)
+ : device_type_(device_type),
+ platform_id_(platform_id),
+ xla_device_metadata_(xla_device_metadata),
+ xla_allocator_(std::move(xla_allocator)),
+ device_allocator_(device_allocator) {
+ CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr));
+ }
+
+ XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+ bool UseMultipleStreams() const {
+ return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+ }
+
+ xla::DeviceMemoryAllocator* allocator() const {
+ return device_allocator_ ? device_allocator_ : xla_allocator_.get();
+ }
+ DeviceType device_type() const { return device_type_; }
+
+ // This is equal to xla_device_metadata()->platform()->id() if
+ // xla_device_metadata() is not nullptr.
+ se::Platform::Id platform_id() const { return platform_id_; }
+
+ // This may be null if the op this XlaPlatformInfo is for was not placed on an
+ // XLA device.
+ const XlaDevice::Metadata* xla_device_metadata() const {
+ return xla_device_metadata_;
+ }
+ bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+ DeviceType device_type_;
+ se::Platform::Id platform_id_;
+
+ // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+ // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+ // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+ const XlaDevice::Metadata* xla_device_metadata_;
+
+ // If the op associated with this XlaPlatformInfo is placed on an XLA device
+ // then device_allocator_ is the xla::Backend's memory allocator and
+ // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device
+ // then device_allocator_ is null and xla_allocator_ points to an appropriate
+ // XlaAllocator instance.
+ std::unique_ptr<XlaAllocator> xla_allocator_;
+ xla::DeviceMemoryAllocator* device_allocator_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+ XlaLocalLaunchBase(OpKernelConstruction* ctx,
+ const std::vector<int>& constants,
+ const std::vector<int>& resources,
+ const NameAttrList& function);
+ XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+ XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+ ~XlaLocalLaunchBase() override = default;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ protected:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+ XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+class XlaCompileOp : public OpKernel {
+ public:
+ explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Indexes of compile-time constant inputs
+ std::vector<int> constants_;
+ // Indexes of resource inputs
+ std::vector<int> resources_;
+
+ NameAttrList function_;
+
+ XlaPlatformInfo platform_info_;
+};
+
+class XlaRunOp : public OpKernel {
+ public:
+ explicit XlaRunOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ XlaPlatformInfo platform_info_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e52ae..133d982360 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) {
return elementwise_ops->count(node.op()) > 0;
}
+// Nodes that XLA can compile are put in `candidates`. Nodes put in
+// `isolated_nodes` must either be unclustered or be put in trivial single-node
+// clusters.
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates) {
+ OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -411,6 +414,8 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
+ VLOG(4) << "Device type for " << node->name() << ": "
+ << device_type.type_string();
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
// is_compilable_fn has already logged the reason if it returned false.
@@ -439,19 +444,56 @@ Status FindCompilationCandidates(
<< node->type_string();
continue;
}
- if (compile_time_const_nodes[node->id()] &&
- !registration->requires_compilation) {
+ if (compile_time_const_nodes[node->id()]) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
- // We need to be able to constant fold the nodes in
- // compile_time_const_nodes given constant inputs (required by XLA) and
- // therefore can't auto-cluster stateful ops since these can never be
- // constant folded.
- VLOG(2) << "Rejecting " << node->name()
- << ": must-be-constant stateful op";
- continue;
+ // It is easiest to demonstrate the problem we're trying to solve with
+ // an example. Say we have this graph:
+ //
+ // shape = RandomUniformInt();
+ // reshape = Reshape(input, shape)
+ //
+ // Both RandomUniformInt and Reshape are compilable by XLA so, absent
+ // any other reason, we will try to put both shape and reshape in the
+ // same cluster. However, since XLA only supports statically shaped
+ // values, it will expect to be able to constant fold `shape` to get a
+ // static shape for `reshape`. This is a problem because side-effecting
+ // ops like RandomUniformInt() cannot be constant folded. We fix this
+ // by putting `shape` and `reshape` in different clusters, which results
+ // in us recompiling `reshape`'s cluster for every new value of `shape`,
+ // making `reshape` statically sized within each compilation. We
+ // simplify the solution even further by disallowing operations like
+ // `shape` from being part of *any* non-trivial cluster. They're either
+ // not compiled by XLA altogether or, if assigned to an XLA_* device
+ // with "must compile" semantics, compiled into a trivial single-op
+ // cluster. This approach leaves some room for improvement, and we can
+ // consider implementing a more aggressive data-flow-analysis based
+ // solution in the future if needed.
+ //
+ // One ugly problem we have to contend with: certain sets of ops *have*
+ // to be in the same cluster because values flowing between them have
+ // types that can't be live-in or live-out of a cluster. These ops are:
+ //
+ // - TensorArray ops operating on the same TensorArray instance.
+ // - Stack ops operating on the same Stack instance.
+ //
+ // To work around this we avoid isolating these specific ops. Because
+ // of this concession it is unsound to auto-cluster them because then
+ // we'd create clusters we could not compile (because we can't constant
+ // fold, say, a TensorArrayRead or a StackPopV2). But we don't
+ // auto-cluster these operations today so we're good for now.
+ const XlaResourceOpInfo* op_info =
+ GetResourceOpInfoForOp(node->type_string());
+ bool is_tensor_array_or_stack_op =
+ op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
+ if (!is_tensor_array_or_stack_op) {
+ VLOG(2) << "Isolating " << node->name()
+ << ": must-be-constant stateful op";
+ isolated_nodes->insert(node);
+ // Keep going and execute all the other checks.
+ }
}
}
// We don't auto-cluster functional control flow nodes containing resource
@@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
+ gtl::FlatSet<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
: Env::Default(),
- is_compilable_fn, &compilation_candidates));
+ is_compilable_fn, &compilation_candidates, &isolated_nodes));
if (compilation_candidates.empty()) {
VLOG(2) << "No compilable candidates";
@@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl(
"Found control flow node in clustering worklist: ",
node_from->type_string());
}
+
+ if (isolated_nodes.count(node_from)) {
+ continue;
+ }
+
string from_scope;
string to_scope;
for (int to : cycles.Successors(from)) {
@@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl(
node_to->assigned_device_name()) {
continue;
}
+ if (isolated_nodes.count(node_to)) {
+ continue;
+ }
// Look for an _XlaScope on both nodes. If both nodes have a
// scope and the scopes do not match, do not cluster along this
// edge. This restriction is overridden if the global_jit_level is ON. If
@@ -931,6 +982,11 @@ Status MarkForCompilationPass::RunImpl(
// Names for each cluster.
std::unordered_map<int, string> cluster_names;
+ if (flags->tf_xla_clustering_debug) {
+ dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph,
+ options.flib_def);
+ }
+
// Mark clusters for compilation that:
// * are placed on a device that requires compilation (an XlaDevice),
// * are explicitly marked for compilation (_XlaCompile=true), or
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index c59770a4c8..4f9145b479 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -894,5 +894,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) {
EXPECT_EQ(clusters["fn_call"], "");
}
+TEST(XlaCompilationTest, RandomShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output shape_shape =
+ ops::Const(root.WithOpName("test/shape_shape"), {2}, {1});
+ Output shape =
+ ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape,
+ ops::Const(root.WithOpName("test/minval"), 1),
+ ops::Const(root.WithOpName("test/maxval"), 20));
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/shape_rng"], "");
+ EXPECT_NE(clusters["test/reshape"], "");
+ EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]);
+}
+
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) {
+ absl::string_view xla_gpu_device =
+ "/job:worker/replica:0/task:0/device:XLA_GPU:0";
+ Scope root = Scope::NewRootScope().ExitOnError();
+ ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1,
+ DT_INT32);
+ Output zero = ops::Const(root.WithOpName("test/zero"), 0);
+ ops::TensorArrayWrite tensor_array_write(
+ root.WithOpName("test/write"), tensor_array.handle, zero,
+ ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow);
+ Output tensor_array_read =
+ ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle,
+ zero, tensor_array_write.flow_out, DT_INT32);
+ Output reshape =
+ ops::Reshape(root.WithOpName("test/reshape"),
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT),
+ tensor_array_read);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(string(xla_gpu_device));
+ }
+ }
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/read"], "");
+ EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index 65669877f7..d56d0f8ccf 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -14,18 +14,35 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
SessionOptions* session_options) {
- // Assign all nodes to the CPU device.
+ // Assign all unassigned nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
+ // Call AddDevices to register the XLA devices.
+ //
+ // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
+ // make this more direct, but probably not worth it solely for this test.
+ std::vector<Device*> devices;
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
+
+ auto delete_devices = gtl::MakeCleanup([&] {
+ for (Device* d : devices) {
+ delete d;
+ }
+ });
+
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc
new file mode 100644
index 0000000000..d8ace628e6
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.cc
@@ -0,0 +1,458 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include <utility>
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+namespace {
+
+using impl::NodeMatcherProperties;
+
+string IndentAllButFirstLine(absl::string_view text) {
+ std::vector<std::string> lines = absl::StrSplit(text, '\n');
+ for (int i = 1; i < lines.size(); i++) {
+ lines[i].insert(0, " ");
+ }
+ return absl::StrJoin(lines, "\n");
+}
+
+template <typename T>
+bool CompareTensor(const Tensor& actual, const Tensor& expected,
+ ::testing::MatchResultListener* listener) {
+ if (actual.NumElements() != expected.NumElements()) {
+ if (listener->IsInterested()) {
+ *listener << "\nwas looking for tensor with " << expected.NumElements()
+ << " elements, found tensor with " << actual.NumElements()
+ << " elements";
+ return false;
+ }
+ }
+
+ for (int64 i = 0, e = actual.NumElements(); i < e; i++) {
+ if (actual.flat<T>()(i) != expected.flat<T>()(i)) {
+ *listener << "\nmismatch in constant tensor at index " << i
+ << " expected = " << expected.flat<T>()(i)
+ << " actual = " << actual.flat<T>()(i);
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool MatchAndExplainTensor(const Tensor& tensor, const Tensor& expected_tensor,
+ ::testing::MatchResultListener* listener) {
+ if (tensor.dtype() != expected_tensor.dtype()) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected tensor of type "
+ << DataType_Name(expected_tensor.dtype())
+ << " but found one of type " << DataType_Name(tensor.dtype());
+ return false;
+ }
+ }
+
+ switch (tensor.dtype()) {
+ case DT_FLOAT:
+ return CompareTensor<float>(tensor, expected_tensor, listener);
+ case DT_DOUBLE:
+ return CompareTensor<double>(tensor, expected_tensor, listener);
+ case DT_INT8:
+ return CompareTensor<int8>(tensor, expected_tensor, listener);
+ case DT_INT16:
+ return CompareTensor<int16>(tensor, expected_tensor, listener);
+ case DT_INT32:
+ return CompareTensor<int32>(tensor, expected_tensor, listener);
+ case DT_INT64:
+ return CompareTensor<int64>(tensor, expected_tensor, listener);
+ case DT_UINT8:
+ return CompareTensor<uint8>(tensor, expected_tensor, listener);
+ case DT_UINT16:
+ return CompareTensor<uint16>(tensor, expected_tensor, listener);
+ case DT_UINT32:
+ return CompareTensor<uint32>(tensor, expected_tensor, listener);
+ case DT_UINT64:
+ return CompareTensor<uint64>(tensor, expected_tensor, listener);
+ default:
+ LOG(FATAL) << "Unsupported dtype " // Crash ok: testonly.
+ << DataType_Name(tensor.dtype());
+ }
+}
+
+using Input = std::pair<const Node*, int>;
+
+struct NodeMatcher : public ::testing::MatcherInterface<const Node*> {
+ bool MatchAndExplain(
+ const Node* node,
+ ::testing::MatchResultListener* listener) const override {
+ if (op && node->type_string() != *op) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected op " << *op << " but found "
+ << node->type_string();
+ }
+ return false;
+ }
+
+ if (assigned_device && node->assigned_device_name() != *assigned_device) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected assigned_device " << *assigned_device
+ << " but found \"" << node->assigned_device_name() << "\"";
+ }
+ return false;
+ }
+
+ if (name && node->name() != *name) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected name " << *name << " but found "
+ << node->name();
+ }
+ return false;
+ }
+
+ if (constant_value) {
+ const TensorProto* proto = nullptr;
+ if (!GetNodeAttr(node->def(), "value", &proto).ok()) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not find \"value\" attribute in node";
+ }
+ return false;
+ }
+
+ Tensor tensor(proto->dtype());
+ if (!tensor.FromProto(*proto)) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not convert TensorProto in \"value\" attribute "
+ "to Tensor";
+ }
+ return false;
+ }
+
+ if (!MatchAndExplainTensor(/*tensor=*/tensor,
+ /*expected_tensor=*/*constant_value,
+ listener)) {
+ return false;
+ }
+ }
+
+ if (input_matchers) {
+ if (input_matchers->size() != node->num_inputs()) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected " << input_matchers->size()
+ << " inputs but node has " << node->num_inputs();
+ }
+ return false;
+ }
+
+ for (int input_idx = 0, e = input_matchers->size(); input_idx < e;
+ input_idx++) {
+ if (!MatchAndExplainInput(node, input_idx, listener)) {
+ return false;
+ }
+ }
+ }
+
+ std::vector<const Node*> control_deps;
+ for (const Edge* e : node->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_deps.push_back(e->src());
+ }
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ if (control_dep_set &&
+ !control_dep_set->MatchAndExplain(control_deps, &inner_listener)) {
+ if (listener->IsInterested()) {
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ explanation = absl::StrCat(", ", explanation, ",");
+ }
+ *listener << "ctrl_deps" << explanation << " does not match expected: ";
+ control_dep_set->DescribeTo(listener->stream());
+ }
+ return false;
+ }
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* os) const override {
+ std::vector<string> predicates;
+
+ if (name) {
+ predicates.push_back(absl::StrCat("name: ", *name));
+ }
+
+ if (op) {
+ predicates.push_back(absl::StrCat("op: ", *op));
+ }
+
+ if (assigned_device) {
+ predicates.push_back(absl::StrCat("assigned device: ", *assigned_device));
+ }
+
+ bool printed_something = !predicates.empty();
+
+ *os << absl::StrJoin(predicates, ", ");
+
+ if (constant_value) {
+ printed_something = true;
+ *os << "constant value: " << constant_value->DebugString();
+ }
+
+ if (input_matchers) {
+ if (!input_matchers->empty()) {
+ printed_something = true;
+ *os << " with " << (input_matchers->size() == 1 ? "only " : "")
+ << "input" << (input_matchers->size() == 1 ? "" : "s") << " ";
+ }
+
+ if (input_matchers->size() == 1) {
+ ::std::stringstream ss;
+ input_matchers->front().DescribeTo(&ss);
+ printed_something = true;
+ *os << "matching " << ss.str();
+ } else {
+ int edge_idx = 0;
+ for (const ::testing::Matcher<Input>& matcher : (*input_matchers)) {
+ *os << "\n [" << edge_idx << "] matching (";
+ ::std::stringstream ss;
+ matcher.DescribeTo(&ss);
+ printed_something = true;
+ *os << IndentAllButFirstLine(ss.str());
+ *os << ")";
+ edge_idx++;
+ }
+ }
+ }
+
+ if (control_dep_set) {
+ printed_something = true;
+ *os << " and control deps ";
+ control_dep_set->DescribeTo(os);
+ }
+
+ if (!printed_something) {
+ *os << "is any node";
+ }
+ }
+
+ bool MatchAndExplainInput(const Node* node, int input_idx,
+ ::testing::MatchResultListener* listener) const {
+ const Edge* edge;
+ if (!node->input_edge(input_idx, &edge).ok()) {
+ if (listener->IsInterested()) {
+ *listener << "\ncould not find incoming edge for input " << input_idx;
+ }
+ return false;
+ }
+
+ ::testing::StringMatchResultListener inner_listener;
+ Input input = {edge->src(), edge->src_output()};
+ if ((*input_matchers)[input_idx].MatchAndExplain(input, &inner_listener)) {
+ return true;
+ }
+
+ if (listener->IsInterested()) {
+ *listener << "\ninput " << input_idx << " does not match expected:\n";
+ (*input_matchers)[input_idx].DescribeTo(listener->stream());
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ *listener << ", " << explanation;
+ }
+ }
+ return false;
+ }
+
+ absl::optional<string> op;
+ absl::optional<string> name;
+ absl::optional<string> assigned_device;
+ absl::optional<Tensor> constant_value;
+ absl::optional<std::vector<::testing::Matcher<Input>>> input_matchers;
+ absl::optional<::testing::Matcher<absl::Span<const Node* const>>>
+ control_dep_set;
+};
+
+// Matches a dst and dst_output on an input edge. Today we only use this with
+// dst_output=0 but we will eventually need to support multi-output operations.
+class InputMatcher : public ::testing::MatcherInterface<Input> {
+ public:
+ InputMatcher(::testing::Matcher<const Node*> src_matcher, int src_output)
+ : src_matcher_(std::move(src_matcher)), src_output_(src_output) {}
+
+ bool MatchAndExplain(
+ Input input, ::testing::MatchResultListener* listener) const override {
+ ::testing::StringMatchResultListener inner_listener;
+ if (!src_matcher_.MatchAndExplain(input.first, &inner_listener)) {
+ if (listener->IsInterested()) {
+ *listener << "\nsource does not match expected ";
+ src_matcher_.DescribeTo(listener->stream());
+ string explanation = inner_listener.str();
+ if (!explanation.empty()) {
+ *listener << "\n\t" << explanation;
+ }
+ }
+ return false;
+ }
+ if (input.second != src_output_) {
+ if (listener->IsInterested()) {
+ *listener << "\nexpected output slot to be " << src_output_
+ << " but found " << input.second;
+ }
+ return false;
+ }
+
+ return true;
+ }
+
+ void DescribeTo(::std::ostream* os) const override {
+ if (src_output_) {
+ *os << "output slot: " << src_output_ << ", source: (";
+ }
+
+ src_matcher_.DescribeTo(os);
+
+ if (src_output_) {
+ *os << ")";
+ }
+ }
+
+ private:
+ ::testing::Matcher<const Node*> src_matcher_;
+ int src_output_;
+};
+
+std::vector<::testing::Matcher<Input>> NodeMatchersToInputMatchers(
+ absl::Span<const ::testing::Matcher<const Node*>> node_matchers) {
+ std::vector<::testing::Matcher<Input>> result;
+ absl::c_transform(node_matchers, std::back_inserter(result),
+ [](::testing::Matcher<const Node*> n) {
+ return ::testing::MakeMatcher(new InputMatcher(n, 0));
+ });
+ return result;
+}
+} // namespace
+
+::testing::Matcher<const Node*> impl::NodeWith(
+ absl::Span<const NodeMatcherProperties> props) {
+ NodeMatcher* matcher = new NodeMatcher();
+ for (const NodeMatcherProperties& prop : props) {
+ if (prop.name()) {
+ DCHECK(!matcher->name);
+ matcher->name = prop.name();
+ }
+
+ if (prop.op()) {
+ DCHECK(!matcher->op);
+ matcher->op = prop.op();
+ }
+
+ if (prop.constant_value()) {
+ DCHECK(!matcher->constant_value);
+ matcher->constant_value = prop.constant_value();
+ }
+
+ if (prop.assigned_device()) {
+ DCHECK(!matcher->assigned_device);
+ matcher->assigned_device = prop.assigned_device();
+ }
+
+ if (prop.input_nodes()) {
+ DCHECK(!matcher->input_matchers);
+ matcher->input_matchers =
+ NodeMatchersToInputMatchers(*prop.input_nodes());
+ }
+
+ if (prop.control_deps()) {
+ DCHECK(!matcher->control_dep_set);
+ matcher->control_dep_set =
+ ::testing::UnorderedElementsAreArray(*prop.control_deps());
+ }
+ }
+
+ return ::testing::MakeMatcher(matcher);
+}
+
+impl::NodeMatcherProperties Name(string name) {
+ impl::NodeMatcherProperties props;
+ props.set_name(std::move(name));
+ return props;
+}
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op) {
+ impl::NodeMatcherProperties props;
+ props.set_op(std::move(op));
+ return props;
+}
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device) {
+ impl::NodeMatcherProperties props;
+ props.set_assigned_device(std::move(assigned_device));
+ return props;
+}
+
+impl::NodeMatcherProperties impl::Inputs(
+ absl::Span<const ::testing::Matcher<const Node*>> inputs) {
+ std::vector<::testing::Matcher<const Node*>> inputs_vector;
+ absl::c_copy(inputs, std::back_inserter(inputs_vector));
+
+ impl::NodeMatcherProperties props;
+ props.set_input_nodes(std::move(inputs_vector));
+ return props;
+}
+
+impl::NodeMatcherProperties impl::CtrlDeps(
+ absl::Span<const ::testing::Matcher<const Node*>> control_deps) {
+ std::vector<::testing::Matcher<const Node*>> control_deps_vector;
+ absl::c_copy(control_deps, std::back_inserter(control_deps_vector));
+
+ impl::NodeMatcherProperties props;
+ props.set_control_deps(std::move(control_deps_vector));
+ return props;
+}
+
+NodeMatcherProperties ConstantValue(
+ const ::tensorflow::Input::Initializer& val) {
+ TF_CHECK_OK(val.status);
+ NodeMatcherProperties props;
+ props.set_constant_value(val.tensor);
+ return props;
+}
+
+::testing::Matcher<const Node*> Const(
+ const ::tensorflow::Input::Initializer& val) {
+ return NodeWith(ConstantValue(val));
+}
+} // namespace matchers
+
+Node* FindNodeByName(Graph* g, absl::string_view name) {
+ for (Node* n : g->nodes()) {
+ if (n->name() == name) {
+ return n;
+ }
+ }
+
+ return nullptr;
+}
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/node_matchers.h b/tensorflow/compiler/jit/node_matchers.h
new file mode 100644
index 0000000000..0437a7e95c
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers.h
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Provides a set of matchers for tensorflow nodes.
+//
+// Example usage:
+//
+// tensorflow::Node* node = ...;
+// EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
+// Inputs(NodeWith(Name("input")))))
+//
+// Matchable node properties (the expressions that go inside NodeWith(...))
+// are:
+//
+// - Name(string): matches the node name exactly. We will probably need to
+// have this take a string matcher soon in the future.
+//
+// - Op(string): matches the op exactly.
+//
+// - AssignedDevice(string): matches the assigned device exactly.
+//
+// - Inputs(<ordered list>): matches the list of non-control inputs to the node
+// exactly (i.e. does not match a suffix or a prefix).
+//
+// - CtrlDeps(<unordered list>): matches the list of control dependences on the
+// node exactly but in any order.
+//
+// - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
+// with the constant value `init`. Implies Op("Const").
+//
+// Node properties may not be repeated in a single NodeWith(...) matcher.
+// E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail. Since ConstantValue
+// implies Op("Const"), a single NodeWith matcher can't have both
+// ConstantValue(...) and Op(...).
+
+#ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+#define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
+
+#include <array>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace testing {
+namespace matchers {
+
+namespace impl {
+
+// -----------------------------------------------------------------------------
+// Implementation details.
+
+// Properties that we match on for a particular Node. If a particular property
+// is nullopt then any value for it is allowed.
+class NodeMatcherProperties {
+ public:
+ using NodeSeqMatcher = std::vector<::testing::Matcher<const Node*>>;
+
+ const absl::optional<string>& name() const { return name_; }
+ const absl::optional<string>& op() const { return op_; }
+ const absl::optional<string>& assigned_device() const {
+ return assigned_device_;
+ }
+ const absl::optional<Tensor>& constant_value() const {
+ return constant_value_;
+ }
+ const absl::optional<NodeSeqMatcher>& input_nodes() const {
+ return input_nodes_;
+ }
+ const absl::optional<NodeSeqMatcher>& control_deps() const {
+ return control_deps_;
+ }
+
+ void set_name(string name) {
+ DCHECK(IsEmpty());
+ name_ = std::move(name);
+ }
+
+ void set_op(string op) {
+ DCHECK(IsEmpty());
+ op_ = std::move(op);
+ }
+
+ void set_assigned_device(string assigned_device) {
+ DCHECK(IsEmpty());
+ assigned_device_ = std::move(assigned_device);
+ }
+
+ void set_constant_value(Tensor constant_value) {
+ DCHECK(IsEmpty());
+ constant_value_ = std::move(constant_value);
+ op_ = "Const";
+ }
+
+ void set_input_nodes(NodeSeqMatcher input_nodes) {
+ DCHECK(IsEmpty());
+ input_nodes_ = std::move(input_nodes);
+ }
+
+ void set_control_deps(NodeSeqMatcher control_deps) {
+ DCHECK(IsEmpty());
+ control_deps_ = std::move(control_deps);
+ }
+
+ bool IsEmpty() const {
+ return !name().has_value() && !op().has_value() &&
+ !input_nodes().has_value() && !control_deps().has_value();
+ }
+
+ private:
+ absl::optional<string> name_;
+ absl::optional<string> op_;
+ absl::optional<string> assigned_device_;
+ absl::optional<Tensor> constant_value_;
+ absl::optional<NodeSeqMatcher> input_nodes_;
+ absl::optional<NodeSeqMatcher> control_deps_;
+};
+
+::testing::Matcher<const Node*> NodeWith(
+ absl::Span<const NodeMatcherProperties> props);
+
+impl::NodeMatcherProperties Inputs(
+ absl::Span<const ::testing::Matcher<const Node*>> inputs);
+
+impl::NodeMatcherProperties CtrlDeps(
+ absl::Span<const ::testing::Matcher<const Node*>> control_deps);
+} // namespace impl
+
+// -----------------------------------------------------------------------------
+// Public interface.
+
+// Matches a node with name `name`.
+impl::NodeMatcherProperties Name(string name);
+
+// Matches a node with op `op`.
+impl::NodeMatcherProperties Op(string op);
+
+// Matches a node with assigned device `assigned_device`.
+impl::NodeMatcherProperties AssignedDevice(string assigned_device);
+
+// Matches a node with inputs `inputs`.
+//
+// `inputs` are ordered; `inputs`[i] must match input i.
+template <typename... Ts>
+impl::NodeMatcherProperties Inputs(Ts... inputs) {
+ return impl::Inputs({inputs...});
+}
+
+// Matches a node with control dependences `control_deps`.
+//
+// `control_deps` are unordered and will match the control deps of a node in any
+// order.
+template <typename... Ts>
+impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) {
+ return impl::CtrlDeps({control_deps...});
+}
+
+// Matches a constant node with value `val`.
+impl::NodeMatcherProperties ConstantValue(
+ const ::tensorflow::Input::Initializer& val);
+
+// The main gmock matcher. See file comment for example usage.
+template <typename... Ts>
+::testing::Matcher<const Node*> NodeWith(Ts... args) {
+ std::array<impl::NodeMatcherProperties, sizeof...(Ts)> array = {args...};
+ return impl::NodeWith(array);
+}
+
+::testing::Matcher<const Node*> Const(
+ const ::tensorflow::Input::Initializer& val);
+} // namespace matchers
+
+// If `g` has a node named `name` returns it, otherwise returns null.
+Node* FindNodeByName(Graph* g, absl::string_view name);
+} // namespace testing
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_
diff --git a/tensorflow/compiler/jit/node_matchers_test.cc b/tensorflow/compiler/jit/node_matchers_test.cc
new file mode 100644
index 0000000000..93a8994307
--- /dev/null
+++ b/tensorflow/compiler/jit/node_matchers_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/node_matchers.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+
+namespace tensorflow {
+namespace testing {
+namespace {
+
+using ::testing::_;
+
+using testing::matchers::AssignedDevice;
+using testing::matchers::ConstantValue;
+using testing::matchers::CtrlDeps;
+using testing::matchers::Inputs;
+using testing::matchers::Name;
+using testing::matchers::NodeWith;
+using testing::matchers::Op;
+
+template <typename M, typename T>
+string Explain(const T& t, const M& m) {
+ ::testing::StringMatchResultListener listener;
+ EXPECT_THAT(t, ::testing::Not(m)); // For the error message.
+ EXPECT_FALSE(m.MatchAndExplain(t, &listener));
+ return listener.str();
+}
+
+TEST(NodeMatchers, CheckAgainstConstant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output placeholder =
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+ EXPECT_THAT(placeholder.node(), NodeWith(Op("Placeholder")));
+ EXPECT_THAT(placeholder.node(), NodeWith(Name("placeholder")));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Op("Placeholder"), Name("placeholder")));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Name("placeholder"), Op("Placeholder")));
+ EXPECT_THAT(placeholder.node(), NodeWith(Inputs()));
+ EXPECT_THAT(placeholder.node(),
+ NodeWith(Op("Placeholder"), Name("placeholder"), Inputs()));
+
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Op("Add"))),
+ "\nexpected op Add but found Placeholder");
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Name("add"))),
+ "\nexpected name add but found placeholder");
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(Inputs(NodeWith()))),
+ "\nexpected 1 inputs but node has 0");
+}
+
+TEST(NodeMatchers, CheckAgainstBinary) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+ Output add = ops::Add(root.WithOpName("add"), placeholder_a, placeholder_b);
+
+ EXPECT_THAT(add.node(), NodeWith(Op("Add"), Name("add"),
+ Inputs(NodeWith(Name("placeholder_a")),
+ NodeWith(Name("placeholder_b")))));
+
+ EXPECT_EQ(Explain(add.node(), NodeWith(Inputs())),
+ "\nexpected 0 inputs but node has 2");
+ EXPECT_EQ(
+ Explain(add.node(), NodeWith(Inputs(NodeWith(Name("blah")), _))),
+ "\ninput 0 does not match expected:\nname: blah, \nsource does not match "
+ "expected name: blah\n\t\nexpected name blah but found placeholder_a");
+ EXPECT_EQ(
+ Explain(add.node(), NodeWith(Inputs(_, NodeWith(Name("blah"))))),
+ "\ninput 1 does not match expected:\nname: blah, \nsource does not match "
+ "expected name: blah\n\t\nexpected name blah but found placeholder_b");
+}
+
+TEST(NodeMatchers, CheckControlDependence) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+ Output placeholder_c =
+ ops::Placeholder(root.WithOpName("placeholder_c"), DT_FLOAT);
+ Output placeholder_d =
+ ops::Placeholder(root.WithOpName("placeholder_d"), DT_FLOAT);
+
+ root.graph()->AddControlEdge(placeholder_a.node(), placeholder_c.node());
+ root.graph()->AddControlEdge(placeholder_b.node(), placeholder_c.node());
+
+ EXPECT_THAT(placeholder_c.node(),
+ NodeWith(Name("placeholder_c"),
+ CtrlDeps(NodeWith(Name("placeholder_a")),
+ NodeWith(Name("placeholder_b")))));
+ EXPECT_THAT(placeholder_d.node(),
+ NodeWith(Name("placeholder_d"), CtrlDeps()));
+
+ EXPECT_EQ(
+ Explain(placeholder_c.node(), NodeWith(CtrlDeps())),
+ "ctrl_deps, which has 2 elements, does not match expected: is empty");
+ EXPECT_EQ(Explain(placeholder_d.node(), NodeWith(CtrlDeps(NodeWith()))),
+ "ctrl_deps does not match expected: has 1 element and that element "
+ "is any node");
+}
+
+TEST(NodeMatchers, ConstVaulue) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output placeholder =
+ ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+ Output const_0d = ops::Const(root.WithOpName("const_0d"), 42);
+
+ Output const_2d = ops::Const(root.WithOpName("const_2d"), {{1, 2}, {4, 3}});
+
+ EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42)));
+ EXPECT_THAT(const_0d.node(), NodeWith(ConstantValue(42), Name("const_0d")));
+
+ EXPECT_THAT(const_2d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}})));
+
+ EXPECT_EQ(Explain(placeholder.node(), NodeWith(ConstantValue(42))),
+ "\nexpected op Const but found Placeholder");
+ EXPECT_EQ(
+ Explain(const_0d.node(), NodeWith(ConstantValue(43))),
+ "\nmismatch in constant tensor at index 0 expected = 43 actual = 42");
+ EXPECT_EQ(
+ Explain(const_0d.node(), NodeWith(ConstantValue({{1, 2}, {4, 3}}))),
+ "\nwas looking for tensor with 4 elements, found tensor with 1 elements");
+ EXPECT_EQ(
+ Explain(const_2d.node(), NodeWith(ConstantValue(42))),
+ "\nwas looking for tensor with 1 elements, found tensor with 4 elements");
+}
+
+TEST(NodeMatchers, AssignedDevice) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ Output placeholder_a =
+ ops::Placeholder(root.WithOpName("placeholder_a"), DT_FLOAT);
+ Output placeholder_b =
+ ops::Placeholder(root.WithOpName("placeholder_b"), DT_FLOAT);
+
+ Output assigned_add =
+ ops::Add(root.WithOpName("assigned_add"), placeholder_a, placeholder_b);
+ assigned_add.node()->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/device:CPU:0");
+
+ Output unassigned_add =
+ ops::Add(root.WithOpName("unassigned_add"), placeholder_a, placeholder_b);
+
+ EXPECT_THAT(
+ assigned_add.node(),
+ NodeWith(AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")));
+ EXPECT_THAT(unassigned_add.node(), NodeWith(AssignedDevice("")));
+
+ EXPECT_EQ(Explain(unassigned_add.node(),
+ NodeWith(AssignedDevice(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))),
+ "\nexpected assigned_device "
+ "/job:localhost/replica:0/task:0/device:CPU:0 but found \"\"");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index 13804c6a05..f72224545b 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -4,9 +4,17 @@ package(
default_visibility = ["//tensorflow/compiler/tf2xla:internal"],
)
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
cc_library(
name = "xla_ops",
srcs = ["xla_ops.cc"],
deps = ["//tensorflow/core:framework"],
alwayslink = 1,
)
+
+tf_gen_op_wrapper_py(
+ name = "xla_ops_wrapper_py",
+ out = "xla_ops.py",
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
+)
diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc
index 1a29c3caab..bcd1a29b1f 100644
--- a/tensorflow/compiler/jit/ops/xla_ops.cc
+++ b/tensorflow/compiler/jit/ops/xla_ops.cc
@@ -51,4 +51,43 @@ REGISTER_OP("XlaClusterOutput")
"Operator that connects the output of an XLA computation to other "
"consumer graph nodes.");
+REGISTER_OP("_XlaCompile")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Input("resources: Nresources * resource")
+ .Attr("Nresources: int >= 0")
+ .Output("key: string")
+ .Output("compilation_successful: bool")
+ .Attr("function: func")
+ // The compilation cache is stateful.
+ .SetIsStateful()
+ .Doc(R"(XLA Compile Op. For use by the XLA JIT only.
+
+Compiles a TensorFlow function into an XLA LocalExecutable and returns a key
+that _XlaRun can use to look up the LocalExecutable and execute it.
+
+key: A key that can be used to look up the local executable compiled by the
+ node and associated metadata.
+
+compilation_successful: True iff the compilation was successful. Always true
+for now.
+)");
+
+REGISTER_OP("_XlaRun")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ .Input("key: string")
+ // XLA random-number generation ops are stateful.
+ // TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
+ .SetIsStateful()
+ .Doc(R"(XLA Run Op. For use by the XLA JIT only.
+
+Executes a TensorFlow function previously compiled into a LocalExecutable by an
+_XlaCompile op.
+)");
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 35872daa65..0feb73a89e 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -60,9 +60,9 @@ class FakeBinaryOp : public OpKernel {
void Compute(OpKernelContext* ctx) override { CHECK(false); }
};
-class FakeResourceVarUpdateOp : public OpKernel {
+class FakeResourceUpdateOp : public OpKernel {
public:
- explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+ explicit FakeResourceUpdateOp(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* ctx) override { CHECK(false); }
@@ -74,10 +74,9 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary")
.HostMemory("host_out"),
FakeBinaryOp);
-REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
- .Device(DEVICE_CPU)
- .HostMemory("something_else"),
- FakeResourceVarUpdateOp);
+REGISTER_KERNEL_BUILDER(
+ Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"),
+ FakeResourceUpdateOp);
Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
FixupSourceAndSinkEdges(graph->get());
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 3ba48e8c31..b98c0cb028 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -34,6 +34,7 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) {
OptionalTensor& optional = variables[i];
optional.name = handle.name();
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
optional.present = true;
optional.value = *variable->tensor();
@@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
- launch_context.PopulateInputs(ctx, result, variables);
+ launch_context.PopulateInputs(ctx, result, variables,
+ /*missing_ctx_input_prefix=*/0);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
TF_RETURN_IF_ERROR(run_result.status());
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
- ctx, result, run_result.ConsumeValueOrDie()));
+ ctx, result, run_result.ConsumeValueOrDie(),
+ /*missing_ctx_input_prefix=*/0));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7e159e3171..003c1d8081 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "Host" (CPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device.h"
@@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797def04..0824c4644e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -373,7 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":"
<< op_kernel->type_string();
- TracingDevice::Compute(op_kernel, context);
+ op_kernel->Compute(context);
}
void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
return status;
}
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+ mutex_lock lock(mu_);
+ sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+ mutex_lock lock(mu_);
+ return sync_on_completion_;
+}
+
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device) {
// Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ffa8c..0f06b3fc80 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
// information for GPU and TPU devices.
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
+ // Instructs this XlaDevice to return 'sync_on_completion' for
+ // RequiresSyncOnCompletion().
+ void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+ bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
private:
xla::LocalClient* client() const;
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
static Status GetMetadataFromDevice(DeviceBase* device,
const XlaDevice::Metadata** metadata);
- mutex mu_;
+ mutable mutex mu_;
// The metadata of this XlaDevice.
const Metadata xla_metadata_;
// Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
// Thread pool used for running closures
std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ // True if the device requires XlaDevice::Sync to be called on completion
+ // regardless of status.
+ bool sync_on_completion_ GUARDED_BY(mu_) = false;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 49c8582682..2ccee79761 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -65,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.HostMemory("resources"), \
KERNEL);
+#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \
+ .Device(DEVICE) \
+ .HostMemory("constants") \
+ .HostMemory("resources"), \
+ KERNEL);
+
+#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL);
+
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ef4466f005..60979556a3 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" (GPU) backend.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
- DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index 4574559674..19e681af0c 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -15,7 +15,7 @@ limitations under the License.
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h"
+#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
+REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
+ kExecAllTypes);
+REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
+
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index affeab4a8c..4f6fc4e068 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -42,13 +42,14 @@ using xla::ShapedBuffer;
} // anonymous namespace
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables) {
+ OpKernelContext* ctx, absl::Span<const int> variables) {
std::map<int, OptionalTensor> snapshot;
for (int i : variables) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, i);
OptionalTensor& tensor = snapshot[i];
if (LookupResource(ctx, handle, &variable).ok()) {
+ core::ScopedUnref scoped_unref(variable);
tf_shared_lock lock(*variable->mu());
tensor.name = handle.name();
tensor.present = true;
@@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables) {
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
@@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs(
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
+ DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
- t = &(ctx->input(arg_num));
+ t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
if (use_multiple_streams_) {
@@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs(
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
- ScopedShapedBuffer output) {
+ ScopedShapedBuffer output, int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
@@ -275,6 +278,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
+ TF_RET_CHECK(kernel->outputs[i].input_index >= 0)
+ << "Invalid input for outputs " << i;
ctx->set_output(i, ctx->input(kernel->outputs[i].input_index));
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
@@ -313,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ int actual_input_index = write.input_index - missing_ctx_input_prefix;
+ if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
@@ -323,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index), &variable,
+ ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 7ac275fab8..326d70a027 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
class XlaAllocator;
@@ -43,7 +44,7 @@ class XlaAllocator;
// resource variable is not initialized, the corresponding OptionalTensor
// will have its `present` field set to false.
std::map<int, OptionalTensor> SnapshotResourceVariables(
- OpKernelContext* ctx, const std::vector<int>& variables);
+ OpKernelContext* ctx, absl::Span<const int> variables);
// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
// Assumes that the Tensorflow allocator permits asynchronous deallocation:
@@ -88,14 +89,24 @@ class XlaComputationLaunchContext {
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly. All elements in kernel's
+ // input_mapping must be greater than or equal to `missing_ctx_input_prefix`
+ // (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- const std::map<int, OptionalTensor>& variables);
+ const std::map<int, OptionalTensor>& variables,
+ int missing_ctx_input_prefix);
// Given the XLA output in `output`, populate all outputs of `ctx`.
+ //
+ // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
+ // missing and adjusts input indices accordingly.
Status PopulateOutputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ xla::ScopedShapedBuffer output,
+ int missing_ctx_input_prefix);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 050d827a09..3cf74fa788 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -277,9 +277,10 @@ tf_xla_py_test(
],
)
+# This test is large because occasionally the cpu test is long for testConcatLargeNumberOfTensors
tf_xla_py_test(
name = "concat_ops_test",
- size = "medium",
+ size = "large",
srcs = ["concat_ops_test.py"],
deps = [
":xla_test",
@@ -581,6 +582,7 @@ tf_xla_py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -976,7 +978,7 @@ tf_xla_py_test(
name = "gather_test",
size = "medium",
srcs = ["gather_test.py"],
- tags = ["noasan"], # times out, http://b/78599043
+ tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@@ -1196,8 +1198,21 @@ tf_xla_py_test(
)
tf_xla_py_test(
- name = "xla_ops_test",
+ name = "quantized_ops_test",
size = "small",
+ srcs = ["quantized_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "xla_ops_test",
+ size = "medium",
srcs = ["xla_ops_test.py"],
disabled_backends = ["cpu_ondemand"],
deps = [
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342787..68f52e796c 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
- minmax_types = set(self.numeric_types) - set(self.complex_types)
+ minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
- for output_type in self.int_types:
+ for output_type in minmax_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 17280e445b..e219cf3d88 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
equality_test=self.ListsAreClose)
def testIntOps(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testBinary(
gen_math_ops.truncate_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
@@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
- if dtype not in self.complex_types: # min/max not supported for complex
+ # min/max not supported for complex
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
@@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([[70], [14]], dtype=dtype))
# Complex support for squared_difference is incidental, see b/68205550
- if dtype not in self.complex_types:
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
@@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(2),
expected=np.array([[5], [2]], dtype=dtype))
+ if dtype in [np.float32, np.float64]:
+ nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1)
+ divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24)
+ np_result = np.true_divide(nums, divs)
+ np_result[:, divs[0] == 0] = 0
+ self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result)
+
if dtype not in self.complex_types: # floordiv unsupported for complex.
self._testBinary(
gen_math_ops.floor_div,
@@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testDivision(dtype)
def testFloatDivision(self):
@@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, 1, -1, 0], dtype=dtype))
def testIntRemainder(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types - {np.int8}:
self._testRemainder(dtype)
def testFloatRemainder(self):
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index 7b114d4f85..1d3979b21b 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -2,90 +2,103 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
def all_backends():
- b = ["cpu"] + plugins.keys()
- if cuda_is_configured():
- return b + ["gpu"]
- else:
- return b
+ b = ["cpu"] + plugins.keys()
+ if cuda_is_configured():
+ return b + ["gpu"]
+ else:
+ return b
-def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
- disabled_backends=None, **kwargs):
- """Generates py_test targets, one per XLA backend.
+def tf_xla_py_test(
+ name,
+ srcs = [],
+ deps = [],
+ tags = [],
+ data = [],
+ main = None,
+ disabled_backends = None,
+ **kwargs):
+ """Generates py_test targets, one per XLA backend.
- This rule generates py_test() targets named name_backend, for each backend
- in all_backends(). The rule also generates a test suite with named `name` that
- tests all backends for the test.
+ This rule generates py_test() targets named name_backend, for each backend
+ in all_backends(). The rule also generates a test suite with named `name` that
+ tests all backends for the test.
- For example, the following rule generates test cases foo_test_cpu,
- foo_test_gpu, and a test suite name foo_test that tests both.
- tf_xla_py_test(
- name="foo_test",
- srcs="foo_test.py",
- deps=[...],
- )
+ For example, the following rule generates test cases foo_test_cpu,
+ foo_test_gpu, and a test suite name foo_test that tests both.
+ tf_xla_py_test(
+ name="foo_test",
+ srcs="foo_test.py",
+ deps=[...],
+ )
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- tags: Tags to apply to the generated targets.
- data: Data dependencies of the target.
- main: Same as py_test's main attribute.
- disabled_backends: A list of backends that should not be tested. Supported
- values include "cpu" and "gpu". If not specified, defaults to None.
- **kwargs: keyword arguments passed onto the generated py_test() rules.
- """
- if disabled_backends == None:
- disabled_backends = []
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ tags: Tags to apply to the generated targets.
+ data: Data dependencies of the target.
+ main: Same as py_test's main attribute.
+ disabled_backends: A list of backends that should not be tested. Supported
+ values include "cpu" and "gpu". If not specified, defaults to None.
+ **kwargs: keyword arguments passed onto the generated py_test() rules.
+ """
+ if disabled_backends == None:
+ disabled_backends = []
- enabled_backends = [b for b in all_backends() if b not in disabled_backends]
- test_names = []
- for backend in enabled_backends:
- test_name = "{}_{}".format(name, backend)
- backend_tags = ["tf_xla_{}".format(backend)]
- backend_args = []
- backend_deps = []
- backend_data = []
- if backend == "cpu":
- backend_args += [
- "--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
- ]
- elif backend == "gpu":
- backend_args += [
- "--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
- ]
- backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_args += ["--test_device=" + plugins[backend]["device"],
- "--types=" + plugins[backend]["types"]]
- backend_tags += plugins[backend]["tags"]
- backend_args += plugins[backend]["args"]
- backend_deps += plugins[backend]["deps"]
- backend_data += plugins[backend]["data"]
- else:
- fail("Unknown backend {}".format(backend))
+ enabled_backends = [b for b in all_backends() if b not in disabled_backends]
+ test_names = []
+ for backend in enabled_backends:
+ test_name = "{}_{}".format(name, backend)
+ backend_tags = ["tf_xla_{}".format(backend)]
+ backend_args = []
+ backend_deps = []
+ backend_data = []
+ if backend == "cpu":
+ backend_args += [
+ "--test_device=XLA_CPU",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ ]
+ elif backend == "gpu":
+ backend_args += [
+ "--test_device=XLA_GPU",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ ]
+ backend_tags += tf_cuda_tests_tags()
+ elif backend in plugins:
+ backend_args += [
+ "--test_device=" + plugins[backend]["device"],
+ "--types=" + plugins[backend]["types"],
+ ]
+ backend_tags += plugins[backend]["tags"]
+ backend_args += plugins[backend]["args"]
+ backend_deps += plugins[backend]["deps"]
+ backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend {}".format(backend))
- native.py_test(
- name=test_name,
- srcs=srcs,
- srcs_version="PY2AND3",
- args=backend_args,
- main="{}.py".format(name) if main == None else main,
- data=data + backend_data,
- deps=deps + backend_deps,
- tags=tags + backend_tags,
- **kwargs
- )
- test_names.append(test_name)
- native.test_suite(name=name, tests=test_names)
+ native.py_test(
+ name = test_name,
+ srcs = srcs,
+ srcs_version = "PY2AND3",
+ args = backend_args,
+ main = "{}.py".format(name) if main == None else main,
+ data = data + backend_data,
+ deps = deps + backend_deps,
+ tags = tags + backend_tags,
+ **kwargs
+ )
+ test_names.append(test_name)
+ native.test_suite(name = name, tests = test_names)
-def generate_backend_suites(backends=[]):
- """Generates per-backend test_suites that run all tests for a backend."""
- if not backends:
- backends = all_backends()
- for backend in backends:
- native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
+def generate_backend_suites(backends = []):
+ """Generates per-backend test_suites that run all tests for a backend."""
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index 37e5318bb5..2d225ad226 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -291,6 +291,41 @@ class ConcatTest(xla_test.XLATestCase):
ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
array_ops.concat([scalar, scalar, scalar], dim)
+ # The purpose of this is to ensure that XLA on GPU will not run out of memory
+ # with too many arguments.
+ def testConcatLargeNumberOfTensors(self):
+ with self.cached_session():
+ with self.test_scope():
+ for concat_dim in range(2):
+ params = {}
+ p = []
+ shape = np.array([7, 13])
+ num_tensors = 1001
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ placeholder = array_ops.placeholder(
+ dtypes.float32, shape=input_shape)
+ p.append(placeholder)
+ params[placeholder] = np.random.rand(*input_shape).astype(
+ np.float32)
+
+ concat_inputs = p
+ c = array_ops.concat(concat_inputs, concat_dim)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ index = [slice(0, params[p[i]].shape[j]) for j in np.arange(2)]
+ index[concat_dim] = slice(
+ cur_offset, cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ self.assertAllEqual(result[index], params[p[i]])
+
class ConcatOffsetTest(xla_test.XLATestCase):
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 0af74c2d8f..9390870e07 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -45,17 +45,21 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def XlaLaunchOpCount(labels):
- """Count how many XlaLaunch labels are present."""
- return sum("XlaLaunch(" in x for x in labels)
+class DenseLayerTest(test.TestCase):
+ def countXlaOps(self, labels):
+ """Count how many XlaCompile/XlaRun labels are present."""
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+ return xla_run_count
-class DenseLayerTest(test.TestCase):
def testDenseLayerAutoJit(self):
"""Tests dense layer compilation in auto-jit mode.
- Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
+ Dense layer should be compiled into a single XlaCompile/XlaRun op pair in
+ auto-jit mode.
"""
os.environ["TF_XLA_FLAGS"] = (
@@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
Dense layer with static shape input tensor should be compiled into a single
- XlaLaunch op by XLA.
+ XlaCompile/XlaRun op pair by XLA.
"""
with self.cached_session() as sess:
@@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(1, XlaLaunchOpCount(labels))
+ self.assertEqual(1, self.countXlaOps(labels))
# No need to check whether ListDiff is compiled or not because ListDiff op
# is not used when input tensor shape is fully defined.
@@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase):
Dense layer uses shape op to get shape of input tensor if its shape is not
fully defined. XLA does not cluster shape op with other operators. But in
experimental_jit_scope, XLA is forced to compile shape op into its own
- cluster, causing dense layer to be split into TWO XlaLaunch ops.
+ cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op
+ pairs.
"""
with self.cached_session() as sess:
@@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = GetRunMetadataLabels(run_metadata)
- self.assertEqual(2, XlaLaunchOpCount(labels))
+ self.assertEqual(2, self.countXlaOps(labels))
self.assertFalse(InLabels(labels, "MatMult"))
diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py
index 089d95daab..a38e1edafe 100644
--- a/tensorflow/compiler/tests/gather_test.py
+++ b/tensorflow/compiler/tests/gather_test.py
@@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase):
indices_tf = constant_op.constant(indices)
gather_t = array_ops.gather(params, indices_tf)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- np_val = params_np[indices]
+ np_val = constant_op.constant(params_np[indices])
self.assertAllEqual(np_val, gather_val)
def testScalar2D(self):
@@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant(2)
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, 2, axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, 2, axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32(self):
@@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase):
indices = constant_op.constant([0, 1, 0, 2])
gather_t = array_ops.gather(params, indices, axis=axis)
gather_val = session.run(gather_t, feed_dict={params: params_np})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testSimpleTwoD32_Int64Indices(self):
@@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase):
params: params_np,
indices: indices_np
})
- expected = np.take(params_np, [0, 1, 0, 2], axis=axis)
+ expected = constant_op.constant(
+ np.take(params_np, [0, 1, 0, 2], axis=axis), dtype)
self.assertAllEqual(expected, gather_val)
def testHigherRank(self):
@@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase):
tf_indices = constant_op.constant(indices, dtype=dtypes.int32)
gather = array_ops.gather(tf_params, tf_indices, axis=axis)
gather_value = sess.run(gather, feed_dict={tf_params: params})
- gather_np = np.take(params, indices, axis=axis)
+ gather_np = constant_op.constant(
+ np.take(params, indices, axis=axis), dtype)
self.assertAllEqual(gather_np, gather_value)
def testIndicesWithDifferentDimensions(self):
diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py
index 6fe5a66e0e..bbe746e28f 100644
--- a/tensorflow/compiler/tests/image_ops_test.py
+++ b/tensorflow/compiler/tests/image_ops_test.py
@@ -605,10 +605,6 @@ class ResizeBilinearTest(xla_test.XLATestCase):
class NonMaxSuppressionTest(xla_test.XLATestCase):
def testNMS128From1024(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
num_boxes = 1024
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
@@ -644,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertEqual(indices_tf.size, max_output_size)
def testNMS3From6Boxes(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
# Three boxes are selected based on IOU.
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
@@ -693,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
# Three boxes are selected based on IOU.
# One is filtered out by score threshold.
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
with compat.forward_compatibility_horizon(2018, 8, 8):
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 0839fb123e..de68ff0e32 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -77,11 +77,11 @@ def InLabels(labels, substr):
return any([substr in x for x in labels])
-def MetadataHasXlaLaunch(run_metadata):
- """Returns true if there is a XlaLaunch kernel in run_metadata's timeline."""
+def MetadataHasXlaOp(run_metadata):
+ """Returns true if there are XlaRun kernels in run_metadata's timeline."""
# TODO(phawkins): find a less hacky way to test whether a kernel ran.
- return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch")
+ return InLabels(RunMetadataLabels(run_metadata), "XlaRun")
class JitLaunchTest(test.TestCase):
@@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase):
# Verifies that the outputs match and that XLA was invoked. 'fn' must take
# the same number of tensors as arguments that are in 'args', and must return
# a tuple of output tensors.
- # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node
- # actually ran. However, it is sometimes possible for XlaLaunch ops to be
- # constant-folded away, so the check is optional.
+ #
+ # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
+ # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
+ # ops to be constant-folded away, so the check is optional.
def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
placeholders = []
@@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase):
print("Compiled Result {}".format(compiled))
if require_kernel_launch:
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
direct = sess.run(direct_op, feeds)
print("Direct Result {}".format(direct))
@@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase):
y = math_ops.add(x, x)
return y, y
- # Exercises compling a function (say, Foo) which calls another
- # function (say, Bar) which is not inlined. When the compiler compiles
- # Foo, it needs to symbolic execute Bar correctly regardless whether
- # Bar is inlined or not.
+ # Exercises compiling a function (say, Foo) which calls another function
+ # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
+ # to symbolically execute Bar correctly regardless of whether Bar is inlined
+ # or not.
# TODO(b/36139787): Re-enable this test when noinline works again.
# Tests compiled=True and noinline=True.
@@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase):
# TODO(phawkins): really we would like to test that there were exactly
# two kernel launches. However, we have no reliable way to determine
# that.
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
expected = np.square(np.dot(dx, dw) + db)
self.assertAllClose(expected, output, rtol=1e-1)
@@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
def testIgnoredArguments(self):
@@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(28, out)
def testLoops(self):
@@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(95), rtol=1e-1)
def testCond(self):
@@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase):
run_metadata=run_metadata,
options=config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE))
- self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assert_(MetadataHasXlaOp(run_metadata))
self.assertAllClose(result, np.float32(6), rtol=1e-1)
def testNestedFunction(self):
@@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase):
self.assertFalse(InLabels(labels, "Log"))
self.assertTrue(InLabels(labels, "Reciprocal"))
self.assertTrue(InLabels(labels, "Mul"))
- self.assertFalse(InLabels(labels, "XlaLaunch"))
+ self.assertFalse(InLabels(labels, "XlaCompile"))
+ self.assertFalse(InLabels(labels, "XlaRun"))
- # Compile the backprop. One XlaLaunch.
+ # Compile the backprop. One XlaCompile/XlaRun pair.
labels = _Run(compiled=True)
self.assertFalse(InLabels(labels, "Log"))
self.assertFalse(InLabels(labels, "Reciprocal"))
self.assertFalse(InLabels(labels, "Mul"))
- self.assertTrue(InLabels(labels, "XlaLaunch"))
+ self.assertTrue(InLabels(labels, "XlaCompile"))
+ self.assertTrue(InLabels(labels, "XlaRun"))
class ElementWiseFusionTest(test.TestCase):
@@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE))
labels = RunMetadataLabels(run_metadata)
- count = sum("XlaLaunch(" in x for x in labels)
- return output, count
+ xla_compile_count = sum("XlaCompile(" in x for x in labels)
+ xla_run_count = sum("XlaRun(" in x for x in labels)
+ self.assertEqual(xla_compile_count, xla_run_count)
+
+ return output, xla_run_count
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py
index 9222db4b7e..c61965b97f 100644
--- a/tensorflow/compiler/tests/matrix_band_part_test.py
+++ b/tensorflow/compiler/tests/matrix_band_part_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.compiler.tests import xla_test
@@ -26,38 +27,167 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MatrixBandPartTest(xla_test.XLATestCase):
+class MatrixBandPartTest(xla_test.XLATestCase, parameterized.TestCase):
- def _testMatrixBandPart(self, dtype, shape):
- with self.cached_session():
- batch_shape = shape[:-2]
- mat = np.ones(shape).astype(dtype)
- batch_mat = np.tile(mat, batch_shape + [1, 1])
- for lower in -1, 0, 1, shape[-2] - 1:
- for upper in -1, 0, 1, shape[-1] - 1:
- band_np = mat
- if lower >= 0:
- band_np = np.triu(band_np, -lower)
- if upper >= 0:
- band_np = np.tril(band_np, upper)
- if batch_shape:
- band_np = np.tile(band_np, batch_shape + [1, 1])
-
- placeholder = array_ops.placeholder(dtype)
- with self.test_scope():
- band = array_ops.matrix_band_part(
- placeholder,
- constant_op.constant(lower, dtype=dtypes.int32),
- constant_op.constant(upper, dtype=dtypes.int32))
- feed_dict = {placeholder: batch_mat}
- self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
-
- def testMatrixBandPart(self):
+ @parameterized.parameters(
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [2,],
+ 'rows': 7,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 1,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 2,
+ 'cols': 7
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 1
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 2
+ },
+ {
+ 'batch_shape': [1, 3, 2],
+ 'rows': 7,
+ 'cols': 7
+ },
+ )
+ def testMatrixBandPart(self, batch_shape, rows, cols):
for dtype in self.float_types:
- for batch_shape in [[], [2,], [1, 3, 2]]:
- for rows in 1, 2, 7:
- for cols in 1, 2, 7:
- self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
+ with self.cached_session():
+ mat = np.ones(batch_shape + [rows, cols]).astype(dtype)
+ batch_mat = np.tile(mat, batch_shape + [1, 1])
+ for lower in -1, 0, 1, rows - 1:
+ for upper in -1, 0, 1, cols - 1:
+ band_np = mat
+ if lower >= 0:
+ band_np = np.triu(band_np, -lower)
+ if upper >= 0:
+ band_np = np.tril(band_np, upper)
+ if batch_shape:
+ band_np = np.tile(band_np, batch_shape + [1, 1])
+
+ placeholder = array_ops.placeholder(dtype)
+ with self.test_scope():
+ band = array_ops.matrix_band_part(
+ placeholder, constant_op.constant(lower, dtype=dtypes.int32),
+ constant_op.constant(upper, dtype=dtypes.int32))
+ feed_dict = {placeholder: batch_mat}
+ self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py
new file mode 100644
index 0000000000..80c338513b
--- /dev/null
+++ b/tensorflow/compiler/tests/quantized_ops_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for quantized operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class QuantizedOpsTest(xla_test.XLATestCase):
+
+ # Verify that quantized types can be clustered by XLA.
+ def testQuantizedTypeRoundtrip(self):
+ with self.cached_session() as session:
+ for dtype in self.quantized_tf_types:
+ in_values = np.array([1, 2, 3, 4, 5, 6])
+ expected = [[1, 2], [3, 4], [5, 6]]
+ with self.test_scope():
+ p = array_ops.placeholder(dtype=dtypes.int32)
+ x = math_ops.cast(p, dtype)
+ x = array_ops.reshape(x, [3, 2])
+
+ value = session.run(x, {p: in_values})
+ self.assertAllEqual(value, expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 6e18344117..36ef6ed5fe 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
- return set(self.numeric_types) - set(self.complex_types)
+ return set(self.numeric_types) - set(
+ self.complex_types) - {np.uint8, np.int8}
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
@@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.random_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testRandomUniformIsInRange(self):
for dtype in self._random_types():
@@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase):
def rng(dtype):
return random_ops.truncated_normal(shape=[2], dtype=dtype)
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- self._testRngIsNotConstant(rng, dtypes.float32)
+ for dtype in self._random_types() & self.float_types:
+ self._testRngIsNotConstant(rng, dtype)
def testTruncatedNormalIsInRange(self):
count = 10000000
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ # TODO(b/34339814): make this test work with 16 bit float types.
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}:
with self.cached_session() as sess:
with self.test_scope():
x = random_ops.truncated_normal(shape=[count], dtype=dtype)
@@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase):
self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3)
def testShuffle1d(self):
- # TODO(b/26783907): this test requires the CPU backend to implement sort.
- if self.device in ["XLA_CPU"]:
- return
with self.cached_session() as sess:
with self.test_scope():
x = math_ops.range(1 << 16)
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 60c2337743..abc822ef36 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
def testSeqLength(self):
for dtype in self.all_types:
- for seq_dtype in self.int_types:
+ for seq_dtype in self.all_types & {np.int32, np.int64}:
self._testBasic(dtype, seq_dtype)
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index 51c04b5c47..dbf4beb693 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,10 +48,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
@@ -60,10 +56,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
def testTopK(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -89,10 +81,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
expected=[x[indices].astype(dtype), indices])
def testTopK2D(self):
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
@@ -122,10 +110,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
@@ -144,10 +128,6 @@ class XlaSortOpTest(xla_test.XLATestCase):
def testTopKInfinities(self):
"""Tests that positive and negative infinity sort correctly."""
- # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU.
- if self.device in ["XLA_CPU", "XLA_GPU"]:
- return
-
# Only bfloat16 is implemented.
bfloat16 = dtypes.bfloat16.as_numpy_dtype
if bfloat16 not in self.numeric_types:
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 1bea7d9355..f3861043b2 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
"""Test cases for stateless random-number generator operators."""
def _random_types(self):
- return [dtypes.float32]
+ return self.float_types & {dtypes.float32, dtypes.float64}
def testDeterminism(self):
# Stateless values should be equal iff the seeds are equal (roughly)
@@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
self.assertTrue(self._anderson_darling(y) < 2.492)
def testTruncatedNormalIsInRange(self):
- # TODO(b/34339814): implement inverse erf support for non-F32 types.
- for dtype in [dtypes.float32]:
+ for dtype in self._random_types():
with self.cached_session() as sess, self.test_scope():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
n = 10000000
@@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
# Department of Scientific Computing website. Florida State University.
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
actual_mean = np.mean(y)
- self.assertAllClose(actual_mean, expected_mean, atol=2e-4)
+ self.assertAllClose(actual_mean, expected_mean, atol=5e-4)
expected_median = mu + probit(
(normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 55a992195f..98a07709c6 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase):
expected=np.array([[2], [5]], dtype=dtype))
def testClipByValue(self):
- # TODO(b/78258593): enable integer types here too.
- for dtype in self.float_types:
+ for dtype in self.numeric_types - self.complex_types:
test_cases = [
(np.array([2, 4, 5], dtype=dtype), dtype(7)), #
(dtype(1), np.array([2, 4, 5], dtype=dtype)), #
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 5b0e57f83f..77f6eee0cf 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
self.assertAllClose(result[i], expected[i], rtol, atol)
def testAllTypeOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
@@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
- # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018.
- if dtype == np.float16 and self.device == "XLA_CPU":
- continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
@@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
def testNumericOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 0f3843dc1e..4cf88fc523 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest
@@ -34,7 +35,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
def _assertOpOutputMatchesExpected(self, op, args, expected,
equality_fn=None):
- with self.cached_session() as session:
+ with self.test_session() as session:
with self.test_scope():
placeholders = [
array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
@@ -180,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dtype=dtype))
def testNeg(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.uint8, np.int8}:
self._assertOpOutputMatchesExpected(
xla.neg,
args=(np.array([1, 2, 3], dtype=dtype),),
@@ -296,6 +297,44 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._assertOpOutputMatchesExpected(
lambda x: xla.transpose(x, [1, 0]), args=(v,), expected=v.T)
+ def testDynamicSlice(self):
+ for dtype in self.numeric_types:
+ self._assertOpOutputMatchesExpected(
+ xla.dynamic_slice,
+ args=(np.arange(1000,
+ dtype=np.int32).astype(dtype).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3, 2])),
+ expected=np.array(
+ np.array([[[573, 574], [583, 584], [593, 594]],
+ [[673, 674], [683, 684], [693, 694]]]),
+ dtype=dtype))
+
+ def testDynamicSliceWithIncorrectStartIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7]), np.array([2, 3, 4]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^start_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and start_indices has shape \[2\].*'))
+
+ def testDynamicSliceWithIncorrectSizeIndicesShape(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = xla.dynamic_slice(
+ np.arange(1000, dtype=np.int32).reshape([10, 10, 10]),
+ np.array([5, 7, 3]), np.array([2, 3]))
+ with self.assertRaises(errors.InvalidArgumentError) as invalid_arg_error:
+ session.run(output)
+ self.assertRegexpMatches(
+ invalid_arg_error.exception.message,
+ (r'^size_indices must be a vector with length equal to input rank, '
+ r'but input rank is 3 and size_indices has shape \[2\].*'))
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 88827cb53b..98a41981cf 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -97,10 +97,23 @@ class XLATestCase(test.TestCase):
])
self._numeric_tf_types = set(
self.int_tf_types | self._float_tf_types | self.complex_tf_types)
-
- self._all_types = set(
- [dtype.as_numpy_dtype for dtype in self._all_tf_types])
+ self.quantized_tf_types = set(
+ dtype for dtype in self._all_tf_types if dtype.is_quantized)
+
+ # Quantized types don't have a numpy equivalent, include them in
+ # all_tf_types but not in all_types.
+ # TODO(b/115960798): Parametrize tests on TF types instead of numpy types
+ # and remove all_types.
+ self._all_types = set(dtype.as_numpy_dtype
+ for dtype in self._all_tf_types
+ if not dtype.is_quantized)
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+ self.signed_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if not dtype.is_unsigned)
+ self.unsigned_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if dtype.is_unsigned)
self._float_types = set(
[dtype.as_numpy_dtype for dtype in self._float_tf_types])
self.complex_types = set([
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 74b131e07e..ba1e3b2b4f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -76,6 +76,7 @@ cc_library(
deps = [
":common",
":dump_graph",
+ ":functionalize_control_flow",
":tf2xla_proto",
":tf2xla_util",
":xla_compiler",
@@ -188,7 +189,6 @@ cc_library(
deps = [
":common",
":dump_graph",
- ":functionalize_control_flow",
":host_compute_metadata_proto",
":sharding_util",
":side_effect_util",
@@ -284,6 +284,7 @@ cc_library(
deps = [
":sharding_util",
":tf2xla_proto",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -479,6 +480,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -506,12 +508,24 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
)
cc_library(
+ name = "functionalize_control_flow_pass_registration",
+ srcs = [
+ "functionalize_control_flow_pass_registration.cc",
+ ],
+ deps = [
+ ":functionalize_control_flow",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "functionalize_while",
srcs = [
"functionalize_while.cc",
@@ -520,6 +534,7 @@ cc_library(
"functionalize_while.h",
],
deps = [
+ ":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
@@ -530,6 +545,7 @@ cc_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
@@ -544,6 +560,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index 8ac5eb5df9..ea8d1b3d14 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -31,9 +31,7 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
out_ops_file = "ops/xla_jit_op",
- deps = [
- "//tensorflow/compiler/jit/ops:xla_ops",
- ],
+ deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 922ae7c79a..027ca6d2d2 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter) {
- // Operators that don't look at the data of their inputs, just the shapes.
- const std::unordered_set<string> metadata_ops = {
- "Rank",
- "Shape",
- "ShapeN",
- "Size",
- };
-
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g,
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
- if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+ if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
+ return;
+ }
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index 0911550f1f..db256e577a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
using xla::StatusOr;
@@ -217,10 +218,6 @@ void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) {
added_node_ancestorid_mapping_[node->id()] = id;
}
-const StateMap::CondState& StateMap::LookupState(const Node* node) const {
- return *LookupCondId(node);
-}
-
void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); }
string StateMap::CondStateToString(const Node* node) const {
@@ -642,7 +639,7 @@ Status Conditional::ExtractBodies(Graph* graph) {
Status Conditional::BuildIfNode(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "Build cond function for " << name();
- NodeDefBuilder builder(name(), "If");
+ NodeDefBuilder builder(name(), "If", library);
const string branch_name[] = {"else_branch", "then_branch"};
for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) {
int branch_index = static_cast<int>(branch);
@@ -791,7 +788,6 @@ Status Conditional::BuildAndReplace(Graph* graph,
TF_RETURN_IF_ERROR(AddInputEdges(graph));
TF_RETURN_IF_ERROR(AddOutputEdges(graph));
TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_));
- for (Node* m : merges_) state_map_->MarkDead(m);
// Check that the if_node doesn't feed into itself.
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -1056,7 +1052,6 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) {
" has no non-dead inputs.");
}
state_map_.MarkDead(node);
- delete_nodes_.push_back(node->id());
VLOG(5) << "removing redundant merge: " << node->name();
while (!node->out_edges().empty()) {
const Edge* oe = *node->out_edges().begin();
@@ -1132,7 +1127,6 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) {
}
} else if (BranchType(switch_branch) != b) {
state_map_.MarkDead(dst_node);
- delete_nodes_.push_back(dst_node->id());
continue;
}
graph_->AddEdge(
@@ -1154,7 +1148,7 @@ Status FunctionalizeCond::DetermineStates(std::vector<Node*> rev_topo_order) {
VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst)
<< " @ " << state_map_.AncestorStateToString(dst);
- if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it");
+ if (VLOG_IS_ON(10)) DumpGraphWithCondState("it");
}
return Status::OK();
}
@@ -1184,23 +1178,62 @@ Status FunctionalizeCond::DetermineAncestorState(Node* dst) {
return Status::OK();
}
-void FunctionalizeCond::DeleteReachableNodes() {
+void FunctionalizeCond::DeleteReachableAndDeadNodes(
+ const std::vector<int>& switch_ids, const std::vector<Node*>& merge_order) {
// Delete all nodes that have been extracted or are reachable from
// deleted/dead nodes. The input and outgoing edges should have already been
// removed.
+ std::deque<int> delete_nodes;
std::vector<bool> deleted(graph_->num_node_ids(), false);
// Don't try to delete source or sink nodes.
deleted[graph_->kSourceId] = true;
deleted[graph_->kSinkId] = true;
- while (!delete_nodes_.empty()) {
- int d_id = delete_nodes_.front();
- delete_nodes_.pop_front();
+
+ // All remaining Switch nodes are not reachable from a Merge node and
+ // removed. This is to account for dead Switch nodes.
+ for (int s_id : switch_ids) {
+ Node* s = graph_->FindNodeId(s_id);
+ if (s == nullptr) continue;
+ for (const Edge* e : s->out_edges()) {
+ // Control outputs of switch nodes (which are unconditionally executed if
+ // the switch is) are not removed as they need not be part of a
+ // conditional.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[s_id] = true;
+ graph_->RemoveNode(s);
+ }
+
+ // All merge nodes should have been transformed at this point and we remove
+ // them from the graph here.
+ for (Node* m : merge_order) {
+ for (const Edge* e : m->out_edges()) {
+ // Similar to control outputs of switch nodes don't remove control
+ // outputs of merge nodes.
+ // TODO(jpienaar): Check cases where output edges still exist here vs
+ // being removed in AddOutputEdges.
+ if (!e->IsControlEdge()) delete_nodes.push_back(e->dst()->id());
+ }
+ deleted[m->id()] = true;
+ graph_->RemoveNode(m);
+ }
+
+ // Enqueue all the dead nodes.
+ for (Node* n : graph_->nodes()) {
+ if (state_map_.IsDead(state_map_.LookupCondId(n))) {
+ delete_nodes.push_back(n->id());
+ }
+ }
+
+ while (!delete_nodes.empty()) {
+ int d_id = delete_nodes.front();
+ delete_nodes.pop_front();
if (deleted[d_id]) continue;
Node* d = graph_->FindNodeId(d_id);
// Switch and Merge nodes could have been deleted already.
if (d == nullptr) continue;
for (const Edge* e : d->out_edges()) {
- delete_nodes_.push_back(e->dst()->id());
+ delete_nodes.push_back(e->dst()->id());
}
deleted[d_id] = true;
graph_->RemoveNode(d);
@@ -1274,7 +1307,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
}
TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order)));
- if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id");
+ if (VLOG_IS_ON(4)) DumpGraphWithCondState("id");
// Sort the merge nodes from innermost outwards.
SortMergeNodes(&merge_order);
@@ -1312,11 +1345,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract");
}
- // All remaining Switch nodes are not reachable from a Merge node and
- // removed. This is to account for dead Switch nodes.
- for (int s_id : switch_ids) delete_nodes_.push_back(s_id);
- for (Node* m : merge_order) delete_nodes_.push_back(m->id());
- DeleteReachableNodes();
+ DeleteReachableAndDeadNodes(switch_ids, merge_order);
return Status::OK();
}
@@ -1331,8 +1360,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) {
state_map_.AncestorStateToString(n)));
}
LOG(INFO) << "FunctionalizeControlFlow (" << name << "): "
- << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name),
- *graph_, library_);
+ << dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_cond_", name), *graph_,
+ library_);
}
Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 28301150ea..1899808940 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -91,10 +91,6 @@ class StateMap {
// Resets the AncestorId for a given node.
void ResetAncestorId(const Node* node, AncestorId id);
- // Returns the CondState for a Node.
- // REQUIRES: node has a non-empty CondState.
- const CondState& LookupState(const Node* node) const;
-
// Marks `node` as dead.
void MarkDead(const Node* node);
@@ -221,8 +217,10 @@ class FunctionalizeCond {
// nesting depth.
void SortMergeNodes(std::vector<Node*>* merge_order);
- // Deletes all nodes in/consumers of `delete_nodes_`.
- void DeleteReachableNodes();
+ // Deletes all nodes in/consumers reachable from switch/merge nodes that were
+ // extracted.
+ void DeleteReachableAndDeadNodes(const std::vector<int>& switch_ids,
+ const std::vector<Node*>& merge_order);
// Member used to unique the CondState to a unique CondId (AncestorState to a
// unique AncestorId) and keep track of CondState/CondId
@@ -232,9 +230,6 @@ class FunctionalizeCond {
// Mapping from merge nodes to predicate.
std::unordered_map<Node*, OutputTensor> merge_to_predicate_;
- // Nodes to be deleted.
- std::deque<int> delete_nodes_;
-
FunctionLibraryDefinition* library_;
Graph* graph_;
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 5932be4e52..98b333a467 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -31,11 +31,18 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
@@ -68,4 +75,165 @@ Status FunctionalizeControlFlow(Graph* graph,
return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library);
}
+Status FunctionalizeControlFlowForFunction(
+ const string& func_name, const string& new_func_name,
+ const protobuf::Map<string, tensorflow::AttrValue>& attrs,
+ FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
+ std::map<string, string>* canonicalized_name_to_new_name) {
+ // Convert the function to Graph.
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
+ Status ret_status = Status::OK();
+ auto cleanup_handle = gtl::MakeCleanup([&]() {
+ auto s = flr->ReleaseHandle(handle);
+ if (!s.ok()) {
+ ret_status.Update(s);
+ }
+ });
+ const FunctionBody* body = flr->GetFunctionBody(handle);
+
+ // If any node has associated functions, functionalize them first.
+ // Gather nodes with associated functions first, because rewriting those nodes
+ // might involve node deletion/addition. Avoid modifying nodes while iterating
+ // it.
+ std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
+ nodes_to_associated_functions;
+ for (auto* n : body->graph->nodes()) {
+ auto associated_functions = GetAssociatedFunctions(*n, flr);
+ if (!associated_functions.empty()) {
+ nodes_to_associated_functions.push_back({n, associated_functions});
+ }
+ }
+ for (auto iter : nodes_to_associated_functions) {
+ Node* n = iter.first;
+ auto associated_functions = iter.second;
+ for (auto& associated_function : associated_functions) {
+ string name = associated_function.func_name();
+ string canonicalized_name = Canonicalize(name, AttrSlice(&attrs));
+ auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
+ string new_name;
+ if (iter != canonicalized_name_to_new_name->end()) {
+ // If we already functionalized this function, skip functionalization
+ // but still rewrite the node.
+ new_name = iter->second;
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ name, new_name, attrs, fld, flr, canonicalized_name_to_new_name));
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ }
+ // Notice that if "n" is a function call, RewriteAssociatedFunction() will
+ // delete it and create a new node instead, making "n" an invalid pointer.
+ // That's fine because in that case, associated_functions will only have
+ // one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ body->graph, n, fld, associated_function, new_name));
+ }
+ }
+
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr);
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *optimized_graph, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
+ *optimized_graph, fld);
+ }
+ FunctionDef functionalized_fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
+ &functionalized_fdef));
+
+ // Add rewritten FunctionDef into library.
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
+
+ return ret_status;
+}
+
+Status FunctionalizeControlFlowPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_before", *graph,
+ options.flib_def);
+ }
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
+ new ProcessFunctionLibraryRuntime(
+ /*device_mgr=*/nullptr, options.session_options->env,
+ TF_GRAPH_DEF_VERSION, options.flib_def, OptimizerOptions()));
+ FunctionLibraryRuntime* flr =
+ pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
+
+ // Find XLA compile ops and its corresponding FunctionDef.
+ static std::map<string, string>* kNodeTypeToFunctionAttrMapping =
+ new std::map<string, string>{
+ {"TPUCompile", "function"},
+ {"XlaLaunch", "function"},
+ };
+ std::map<string, string> canonicalized_name_to_new_name;
+ for (Node* n : graph->nodes()) {
+ auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
+ if (it == kNodeTypeToFunctionAttrMapping->end()) {
+ continue;
+ }
+ const string func_attr = it->second;
+ if (kNodeTypeToFunctionAttrMapping->find(n->type_string()) !=
+ kNodeTypeToFunctionAttrMapping->end()) {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), func_attr, &func));
+ VLOG(2) << "Graph has node " << n->type_string()
+ << ". Corresponding function: " << func.name();
+ string new_func_name = options.flib_def->UniqueFunctionName(
+ absl::StrCat(func.name(), "_f15n_"));
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
+ func.name(), new_func_name, func.attr(), options.flib_def, flr,
+ &canonicalized_name_to_new_name));
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
+ }
+
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile("functionalize_control_flow_after", *graph,
+ options.flib_def);
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index 55600f2a8b..ba99205640 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -32,6 +33,14 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
Graph* graph,
FunctionLibraryDefinition* library);
+// This pass looks at the graph and all associated FunctionDefs, and turns
+// traditional control flow structure (Switch/Merge/etc.) into functional
+// control flow structure (If/While).
+class FunctionalizeControlFlowPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
new file mode 100644
index 0000000000..a10a9d0499
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_pass_registration.cc
@@ -0,0 +1,25 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
+
+namespace tensorflow {
+
+// This pass is required for some AOT backends and all JIT backends, so this
+// file exists as a separate lib and will be linked to both AOT and JIT.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 27,
+ FunctionalizeControlFlowPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index c068a4110c..c3841f996f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
@@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
- std::initializer_list<Input>{less, y, x}, then_fn,
- else_fn, {DT_INT32});
+ auto if_op = ops::If(scope.WithOpName(op_name), less,
+ std::initializer_list<Input>{less, y, x}, {DT_INT32},
+ then_fn, else_fn);
auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
- // TODO(jpienaar): Create wrapper for IfOp.
- for (NodeDef& n : *expected.mutable_node()) {
- if (n.op() == "XlaIf") n.set_op("If");
- }
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
@@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
NameAttrList* body) {
for (const NodeDef& node : graph.node()) {
- if (node.op() == "XlaWhile") {
+ if (node.op() == "While") {
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
*cond = *result;
@@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
return Status::OK();
}
}
- return errors::NotFound("No XlaWhile node found in graph");
+ return errors::NotFound("No While node found in graph");
}
// Graph:
@@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{x, y}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{x, y}, cond_fn, body_fn);
auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
GraphDef expected;
@@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
- auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
- std::initializer_list<Input>{zero, y, x, var},
- outer_cond_fn, outer_body_fn);
+ auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
+ std::initializer_list<Input>{zero, y, x, var},
+ outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
- ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
- std::initializer_list<Input>{one_j, arg1, arg2, arg3},
- inner_cond_fn, inner_body_fn);
+ ops::While(scope.WithOpName("outer/LoopCond_1"),
+ std::initializer_list<Input>{one_j, arg1, arg2, arg3},
+ inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 7f45e3bffa..7c3ad448ef 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
@@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
}
- // Builds the condition and body functions.
+ // Builds the condition and body functions. Notice that we call
+ // FunctionalizeCond() on cond_graph and body_graph because we might have
+ // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
+ // before they are encapsulated in FunctionDef.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ FixupSourceAndSinkEdges(cond_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+ FixupSourceAndSinkEdges(body_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -653,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
// There should be no cycle at this point, since while loops have been removed
// from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
+ // Check that the newly added While nodes don't feed into themselves.
for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
+ if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 82e9eef005..c019a28e89 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h
index ab7cac7100..e9f02201cf 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.h
+++ b/tensorflow/compiler/tf2xla/graph_compiler.h
@@ -55,17 +55,17 @@ namespace tensorflow {
// op registration infrastructure instead of FunctionLibraryRuntime.
class GraphCompiler {
public:
- GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
- Graph* graph, FunctionLibraryRuntime* flib,
+ GraphCompiler(XlaCompilationDevice* device, Graph* graph,
+ FunctionLibraryRuntime* flib,
ScopedStepContainer* step_container)
- : xla_context_(xla_context),
- device_(device),
+ : device_(device),
graph_(graph),
flib_(flib),
step_container_(step_container) {}
- // Compiles the graph. The results are written in `xla_context` that is passed
- // into the compiler.
+ // Compiles the graph. The results are written in xla_context stored in the
+ // resource_manager of the 'XlaCompilationDevice' that's passed into the
+ // constructor.
Status Compile();
private:
@@ -82,7 +82,6 @@ class GraphCompiler {
// using `compiler_`.
Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
- XlaContext* xla_context_;
XlaCompilationDevice* device_;
Graph* graph_;
FunctionLibraryRuntime* flib_;
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 46794f7b50..3e823254d3 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -113,6 +113,7 @@ tf_kernel_library(
"shape_util.h",
],
deps = [
+ ":conv_op_helpers",
":if_op",
":while_op",
"//tensorflow/compiler/tf2xla:common",
@@ -172,6 +173,27 @@ tf_kernel_library(
],
)
+cc_library(
+ name = "conv_op_helpers",
+ srcs = ["conv_op_helpers.cc"],
+ hdrs = ["conv_op_helpers.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/lib:constants",
+ "//tensorflow/compiler/xla/client/lib:numeric",
+ "//tensorflow/core:framework",
+ "//tensorflow/core/kernels:bounds_check",
+ "//tensorflow/core/kernels:conv_ops",
+ "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
tf_kernel_library(
name = "while_op",
srcs = ["while_op.cc"],
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index df17da4c1c..66676452d0 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions));
XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
+// Implementation of DivNoNan. Pseudo-code:
+// if (y == 0) {
+// return 0
+// } else {
+// return x / y;
+// }
+static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
+ xla::XlaOp y, const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto y_equals_0 = xla::Eq(y, zero);
+ auto zeros = xla::ZerosLike(x);
+ auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y));
+ return result;
+}
+XLA_MAKE_BINARY(DivNoNan,
+ DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
// Implementation of FloorDiv. Pseudo-code:
// if ((x < 0) != (y < 0)) {
// T abs_x = std::abs(x);
@@ -66,6 +85,9 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ if (DataTypeIsUnsigned(dtype)) {
+ return xla::Div(x, y);
+ }
auto zero = XlaHelpers::Zero(b, dtype);
auto one = XlaHelpers::One(b, dtype);
auto different_sign = xla::Ne(xla::Lt(x, zero), xla::Lt(y, zero));
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
index f410605104..0ae23aa6df 100644
--- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -37,6 +37,16 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Used to determine the number of Tensors allowed in a Concat op to prevent
+// going over the max gpu parameter memory size. This is an issue because concat
+// is variadic and can have an unlimited number of arguments when called.
+// Concat ops with more Tensors than this will be split into multiple concat
+// ops.
+//
+// TODO(b/112613927): Remove the logic here and put it properly in an HLO pass
+// along with boxing large numbers of parameters.
+constexpr int64 kMaxConcatArgsPerOp = 500;
+
// --------------------------------------------------------------------------
class ConcatBaseOp : public XlaOpKernel {
public:
@@ -74,6 +84,7 @@ class ConcatBaseOp : public XlaOpKernel {
// Make a vector holding the XlaOp for each of the inputs that has non-zero
// elements.
std::vector<xla::XlaOp> input_data;
+ std::vector<xla::XlaOp> partial_concats;
int output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) {
@@ -94,10 +105,30 @@ class ConcatBaseOp : public XlaOpKernel {
input_data.push_back(handle);
}
output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+
+ // Concat is associative, so it can be split into many operations when too
+ // many arguments are in a single op. This is a temporary workaround for
+ // b/112613927 where too many parameters in an XlaLaunchOp later result in
+ // too many parameters to a single GPU kernel.
+ if (i && i % kMaxConcatArgsPerOp == 0) {
+ partial_concats.push_back(
+ xla::ConcatInDim(ctx->builder(), input_data, axis));
+ input_data.clear();
+ }
}
+ // Add any inputs that have not been put into another concat yet.
+ partial_concats.insert(partial_concats.end(), input_data.begin(),
+ input_data.end());
VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
- ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis));
+ // Don't add an additional "identity" concatenate for better readibility of
+ // IR.
+ if (partial_concats.size() == 1) {
+ ctx->SetOutput(0, partial_concats.front());
+ } else {
+ ctx->SetOutput(0,
+ xla::ConcatInDim(ctx->builder(), partial_concats, axis));
+ }
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
new file mode 100644
index 0000000000..c9a1be4940
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -0,0 +1,509 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Returns the expanded size of a filter used for depthwise convolution.
+// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
+xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) {
+ int num_dims = shape.dimensions_size();
+ CHECK_GE(num_dims, 2); // Crash OK
+ xla::Shape expanded_shape = shape;
+ expanded_shape.set_dimensions(
+ num_dims - 1,
+ shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1));
+ return expanded_shape;
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create a one tensor, A, that is [3]
+// 0 1 2
+//
+// and another tensor, B, that is [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B it by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast the B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and broadcasted B in dimension 2 amd return the result at
+// the beginning of the comment.
+xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
+ xla::XlaBuilder* builder) {
+ xla::Shape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 1);
+ int64 input_feature =
+ filter_shape.dimensions(filter_shape.dimensions_size() - 2);
+
+ // Create a M sized linspace and an M*N sized linspace that will be
+ // broadcasted into perpendicular dimensions and compared.
+ xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
+ xla::XlaOp expanded_feature_iota =
+ xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ xla::Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the N*M linspace to [H, W, ..., M, M*N].
+ std::vector<int64> expanded_feature_broadcast_dims(
+ expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota =
+ xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dimensions_size() - 2});
+}
+
+// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
+// build a depthwise convolution.
+xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter) {
+ int64 input_feature_dim = filter_shape.dimensions_size() - 2;
+ int64 output_feature_dim = filter_shape.dimensions_size() - 1;
+ int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim);
+ int64 input_feature = filter_shape.dimensions(input_feature_dim);
+
+ // Create a [H, W, ..., 1, N*M] reshape of the filter.
+ xla::Shape implicit_broadcast_filter_shape = filter_shape;
+ implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1);
+ implicit_broadcast_filter_shape.set_dimensions(
+ output_feature_dim, depthwise_multiplier * input_feature);
+ return xla::Reshape(
+ filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions()));
+}
+
+// Reduces the results of the convolution with an expanded filter to the
+// non-expanded filter.
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape,
+ const xla::XlaOp& filter_backprop,
+ xla::XlaBuilder* builder) {
+ auto masked_expanded_filter =
+ xla::Select(CreateExpandedFilterMask(filter_shape, builder),
+ filter_backprop, xla::ZerosLike(filter_backprop));
+
+ auto elem_type = filter_shape.element_type();
+ return xla::Reshape(
+ // This reduce does not need inputs to be converted with
+ // XlaHelpers::SumAccumulationType() since the select above guarantees
+ // that only one element is non zero, so there cannot be accumulated
+ // precision error.
+ xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type),
+ CreateScalarAddComputation(elem_type, builder),
+ {filter_shape.dimensions_size() - 2}),
+ xla::AsInt64Slice(filter_shape.dimensions()));
+}
+
+// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA
+// convolutions (as currently implemented).
+Status CheckConvAttrs(const ConvOpAttrs& attrs) {
+ const int num_dims = attrs.num_spatial_dims + 2;
+ if (attrs.strides.size() != num_dims) {
+ return errors::InvalidArgument("Sliding window strides field must specify ",
+ num_dims, " dimensions");
+ }
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+ if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not yet support strides in the batch and "
+ "depth dimensions.");
+ }
+ if (attrs.dilations.size() != num_dims) {
+ return errors::InvalidArgument("Dilations field must specify ", num_dims,
+ " dimensions");
+ }
+ if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) {
+ return errors::Unimplemented(
+ "Current implementation does not support dilations in the batch and "
+ "depth dimensions.");
+ }
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ if (attrs.dilations[input_dim] < 1) {
+ return errors::Unimplemented("Dilation values must be positive; ", i,
+ "th spatial dimension had dilation ",
+ attrs.dilations[input_dim]);
+ }
+ }
+ return Status::OK();
+}
+
+// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes
+// to TensorShapes.
+Status ConvBackpropComputeDimensionsV2XlaShapes(
+ StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
+ const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
+ absl::Span<const int32> dilations, const std::vector<int32>& strides,
+ Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+ TensorShape input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape));
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
+ return ConvBackpropComputeDimensionsV2(
+ label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
+ out_backprop_tensor_shape, dilations, strides, padding, data_format,
+ dims);
+}
+
+} // anonymous namespace
+
+xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims,
+ bool depthwise,
+ OpKernelConstruction* ctx) {
+ ConvOpAttrs attrs;
+ attrs.num_spatial_dims = num_spatial_dims;
+ attrs.depthwise = depthwise;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
+
+ string data_format;
+ TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format));
+ if (!FormatFromString(data_format, &attrs.data_format)) {
+ return errors::InvalidArgument("Invalid data format: ", data_format);
+ }
+
+ return attrs;
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = conv_input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input));
+ // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth]
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+
+ // For 2D convolution, there should be 4 dimensions.
+ int num_dims = attrs.num_spatial_dims + 2;
+ if (input_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument("input must be ", num_dims, "-dimensional",
+ input_shape.DebugString());
+ }
+ if (filter_shape.dimensions_size() != num_dims) {
+ return errors::InvalidArgument(
+ "filter must be ", num_dims,
+ "-dimensional: ", filter_shape.DebugString());
+ }
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims);
+ // The 'C' dimension for input is in_depth. It must be the same as
+ // the filter's in_depth.
+ if (in_depth != input_shape.dimensions(feature_dim)) {
+ return errors::InvalidArgument(
+ "input and filter must have the same depth: ", in_depth, " vs ",
+ input_shape.dimensions(feature_dim));
+ }
+
+ if (attrs.depthwise) {
+ filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
+ }
+
+ xla::ConvolutionDimensionNumbers dims;
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+
+ dims.set_input_batch_dimension(batch_dim);
+ dims.set_output_batch_dimension(batch_dim);
+ dims.set_input_feature_dimension(feature_dim);
+ dims.set_output_feature_dimension(feature_dim);
+ dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims);
+ dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dims.add_input_spatial_dimensions(dim);
+ dims.add_kernel_spatial_dimensions(i);
+ dims.add_output_spatial_dimensions(dim);
+ window_strides[i] = attrs.strides.at(dim);
+ rhs_dilation[i] = attrs.dilations.at(dim);
+
+ int64 unused_output_size;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
+ input_shape.dimensions(dim), filter_shape.dimensions(i),
+ rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size,
+ &padding[i].first, &padding[i].second));
+ }
+
+ return xla::ConvGeneralDilated(
+ conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation,
+ dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ int num_dims = attrs.num_spatial_dims + 2;
+ int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ auto* builder = filter.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(out_backprop));
+
+ xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
+ out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
+ attrs.data_format, &dims));
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(batch_dim);
+ dnums.set_output_batch_dimension(batch_dim);
+ dnums.set_input_feature_dimension(feature_dim);
+ dnums.set_output_feature_dimension(feature_dim);
+
+ // TF filter shape is [ H, W, ..., inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1);
+ dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims);
+
+ std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims);
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> lhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(i);
+ dnums.add_output_spatial_dimensions(dim);
+
+ kernel_spatial_dims[i] = i;
+ padding[i] = {dims.spatial_dims[i].pad_before,
+ dims.spatial_dims[i].pad_after};
+ lhs_dilation[i] = dims.spatial_dims[i].stride;
+ rhs_dilation[i] = attrs.dilations[dim];
+ }
+
+ // Mirror the filter in the spatial dimensions.
+ xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
+ return xla::ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
+ lhs_dilation, rhs_dilation, dnums,
+ /*feature_group_count=*/
+ attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) /
+ filter_shape.dimensions(attrs.num_spatial_dims + 1)
+ : 1);
+}
+
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs) {
+ TF_RETURN_IF_ERROR(CheckConvAttrs(attrs));
+
+ auto* builder = activations.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape activations_shape,
+ builder->GetShape(activations));
+ TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape,
+ builder->GetShape(gradients));
+ const xla::Shape expanded_filter_shape =
+ attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
+ : filter_shape;
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ ConvBackpropDimensions dims;
+ TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
+ type_string, attrs.num_spatial_dims, activations_shape,
+ expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
+ attrs.padding, attrs.data_format, &dims));
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+
+ // The last two dimensions of the filter are the input and output shapes.
+ int num_dims = attrs.num_spatial_dims + 2;
+ int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format);
+ int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format);
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_input_batch_dimension(c_dim);
+ dnums.set_input_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
+ // where the batch becomes the input feature for the convolution.
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims);
+ std::vector<int64> rhs_dilation(attrs.num_spatial_dims);
+ std::vector<int64> window_strides(attrs.num_spatial_dims);
+ std::vector<int64> ones(attrs.num_spatial_dims, 1);
+
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(attrs.num_spatial_dims);
+ dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1);
+
+ for (int i = 0; i < attrs.num_spatial_dims; ++i) {
+ int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
+ dnums.add_input_spatial_dimensions(dim);
+ dnums.add_kernel_spatial_dimensions(dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int64 padded_in_size =
+ dims.spatial_dims[i].expanded_output_size +
+ (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim];
+
+ // However it can be smaller than input_rows: in this
+ // case it means some of the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+ // we need to ignore some training elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int64 pad_before =
+ attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
+
+ padding[i] = {pad_before, pad_total - pad_before};
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
+ }
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop =
+ xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
+ /*lhs_dilation=*/ones, rhs_dilation, dnums);
+
+ if (attrs.depthwise) {
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ filter_shape, filter_backprop, activations.builder());
+ }
+
+ return filter_backprop;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
new file mode 100644
index 0000000000..6e1b70a478
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+// This header exposes utilities for translating TensorFlow convolution ops into
+// XLA ops.
+//
+// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g.
+// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in
+// this header to implement a new and exciting convolution op, for example a
+// fused TensorFlow op that contains a convolution and other things.
+
+namespace tensorflow {
+
+// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA
+// convolution.
+struct ConvOpAttrs {
+ // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`.
+ static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise,
+ OpKernelConstruction* ctx);
+
+ bool depthwise;
+ int num_spatial_dims;
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Creates a new XLA forward or backward convolution with the given inputs and
+// attributes.
+xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string,
+ xla::XlaOp conv_input,
+ xla::XlaOp filter,
+ const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp(
+ StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter,
+ xla::XlaOp out_backprop, const ConvOpAttrs& attrs);
+xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp(
+ StringPiece type_string, xla::XlaOp activations,
+ const xla::Shape& filter_shape, xla::XlaOp gradients,
+ const ConvOpAttrs& attrs);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index 674720e22f..cd7c820be0 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -15,12 +15,17 @@ limitations under the License.
// XLA-specific Ops for 2D convolution.
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,250 +38,28 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
-
namespace {
-// Returns the expanded size of a filter used for depthwise convolution.
-// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N].
-TensorShape ExpandedFilterShapeForDepthwiseConvolution(
- const TensorShape& shape) {
- int num_dims = shape.dims();
- CHECK_GE(num_dims, 2);
- TensorShape expanded_shape = shape;
- expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) *
- shape.dim_size(num_dims - 1));
- return expanded_shape;
-}
-
-// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
-xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return xla::Broadcast(XlaHelpers::Zero(builder, dtype),
- expanded_filter_shape.dim_sizes());
-}
-
-// Create a mask for depthwise convolution that will make a normal convolution
-// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
-// depthwise filter this returns a [2, 2, 3, 6] tensor
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// 1 1 0 0 0 0 1 1 0 0 0 0
-// 0 0 1 1 0 0 0 0 1 1 0 0
-// 0 0 0 0 1 1 0 0 0 0 1 1
-//
-// The first step is to create a one tensor, A, that is [3]
-// 0 1 2
-//
-// and another tensor, B, that is [3 * 2]
-// 0 1 2 3 4 5
-//
-// and divide B it by 2 to get
-// 0 0 1 1 2 2
-//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
-xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape,
- xla::XlaBuilder* builder) {
- TensorShape expanded_filter_shape =
- ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
- int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
-
- // Create a M sized linspace and an M*N sized linspace that will be
- // broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
- xla::XlaOp expanded_feature_iota =
- xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
-
- // Divide the M*N sized linspace by the depthwise_multiplier to create
- // [0 0 1 1 2 2] in the example in the function comment.
- expanded_feature_iota =
- xla::Div(expanded_feature_iota,
- XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
- depthwise_multiplier));
-
- // Broadcast the N*M linspace to [H, W, ..., M, M*N].
- auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
- expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota =
- xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
- // Compare the broadcasted linspace to the input feature linspace in the
- // input feature dimension to create a diagonal predicate.
- return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dims() - 2});
-}
-
-// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
-// build a depthwise convolution.
-xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape,
- const xla::XlaOp& filter) {
- int64 input_feature_dim = filter_shape.dims() - 2;
- int64 output_feature_dim = filter_shape.dims() - 1;
- int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim);
- int64 input_feature = filter_shape.dim_size(input_feature_dim);
-
- // Create a [H, W, ..., 1, N*M] reshape of the filter.
- TensorShape implicit_broadcast_filter_shape = filter_shape;
- implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1);
- implicit_broadcast_filter_shape.set_dim(output_feature_dim,
- depthwise_multiplier * input_feature);
- return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
-}
-
-// Reduces the results of the convolution with an expanded filter to the
-// non-expanded filter.
-xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx,
- const TensorShape& filter_shape,
- DataType dtype,
- const xla::XlaOp& filter_backprop,
- xla::XlaBuilder* builder) {
- auto masked_expanded_filter = xla::Select(
- CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
- CreateExpandedZero(filter_shape, dtype, builder));
- return xla::Reshape(
- // This reduce does not need inputs to be converted with
- // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with
- // ExpandedZero guarantees that only one element is non zero, so there
- // cannot be accumulated precision error.
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
- *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}),
- filter_shape.dim_sizes());
-}
-
class ConvOp : public XlaOpKernel {
public:
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
-
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape input_shape = ctx->InputShape(0);
- // Input filter is of the following dimensions:
- // [ filter_rows, filter_cols, ..., in_depth, out_depth]
- const TensorShape filter_shape = ctx->InputShape(1);
-
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(
- ctx, input_shape.dims() == num_dims(),
- errors::InvalidArgument("input must be ", num_dims(), "-dimensional",
- input_shape.DebugString()));
- OP_REQUIRES(
- ctx, filter_shape.dims() == num_dims(),
- errors::InvalidArgument("filter must be ", num_dims(),
- "-dimensional: ", filter_shape.DebugString()));
-
- // The last two dimension of the filter are the input and output shapes.
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims_);
-
- // The 'C' dimension for input is in_depth. It must be the same as
- // the filter's in_depth.
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", in_depth,
- " vs ", input_shape.dim_size(feature_dim)));
-
- xla::XlaOp filter = ctx->Input(1);
- if (depthwise_) {
- filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter);
- }
-
- xla::ConvolutionDimensionNumbers dims;
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_, 1);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
-
- dims.set_input_batch_dimension(batch_dim);
- dims.set_output_batch_dimension(batch_dim);
- dims.set_input_feature_dimension(feature_dim);
- dims.set_output_feature_dimension(feature_dim);
- dims.set_kernel_input_feature_dimension(num_spatial_dims_);
- dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dims.add_input_spatial_dimensions(dim);
- dims.add_kernel_spatial_dimensions(i);
- dims.add_output_spatial_dimensions(dim);
- window_strides[i] = strides_.at(dim);
- rhs_dilation[i] = dilations_.at(dim);
-
- int64 unused_output_size;
- OP_REQUIRES_OK(
- ctx, GetWindowedOutputSizeVerboseV2(
- input_shape.dim_size(dim), filter_shape.dim_size(i),
- rhs_dilation[i], window_strides[i], padding_,
- &unused_output_size, &padding[i].first, &padding[i].second));
- }
-
- xla::XlaOp conv = xla::ConvGeneralDilated(
- ctx->Input(0), filter, window_strides, padding, lhs_dilation,
- rhs_dilation, dims,
- /*feature_group_count=*/depthwise_ ? in_depth : 1);
- ctx->SetOutput(0, conv);
+ xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
+ OP_REQUIRES_OK(ctx, conv.status());
+ ctx->SetOutput(0, conv.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvOp);
@@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel {
public:
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- OP_REQUIRES(ctx, strides_.size() == num_dims(),
- errors::InvalidArgument("Sliding window strides field must "
- "specify ",
- num_dims(), " dimensions"));
- int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
- OP_REQUIRES(
- ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- TensorShape input_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
-
- const TensorShape filter_shape = ctx->InputShape(1);
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, input_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- auto filter = ctx->Input(1);
- auto out_backprop = ctx->Input(2);
-
- // The input gradients are computed by a convolution of the output
- // gradients and the filter, with some appropriate padding. See the
- // comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
- dnums.set_input_batch_dimension(batch_dim);
- dnums.set_output_batch_dimension(batch_dim);
- dnums.set_input_feature_dimension(feature_dim);
- dnums.set_output_feature_dimension(feature_dim);
-
- // TF filter shape is [ H, W, ..., inC, outC ]
- // Transpose the input and output features for computing the gradient.
- dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1);
- dnums.set_kernel_output_feature_dimension(num_spatial_dims_);
-
- std::vector<int64> kernel_spatial_dims(num_spatial_dims_);
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> lhs_dilation(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(i);
- dnums.add_output_spatial_dimensions(dim);
-
- kernel_spatial_dims[i] = i;
- padding[i] = {dims.spatial_dims[i].pad_before,
- dims.spatial_dims[i].pad_after};
- lhs_dilation[i] = dims.spatial_dims[i].stride;
- rhs_dilation[i] = dilations_[dim];
- }
-
- // Mirror the filter in the spatial dimensions.
- xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims);
-
- // activation gradients
- // = gradients (with padding and dilation) <conv> mirrored_weights
- xla::XlaOp in_backprop = xla::ConvGeneralDilated(
- out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
- lhs_dilation, rhs_dilation, dnums,
- /*feature_group_count=*/
- depthwise_ ? out_backprop_shape.dim_size(feature_dim) /
- filter_shape.dim_size(num_spatial_dims_ + 1)
- : 1);
-
- ctx->SetOutput(0, in_backprop);
+ TensorShape input_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape));
+ xla::Shape input_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> in_backprop =
+ MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape,
+ ctx->Input(1), ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, in_backprop.status());
+ ctx->SetOutput(0, in_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp);
@@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel {
public:
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
- : XlaOpKernel(ctx),
- num_spatial_dims_(num_spatial_dims),
- depthwise_(depthwise) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
- string data_format;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
- OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : XlaOpKernel(ctx) {
+ xla::StatusOr<ConvOpAttrs> attrs =
+ ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx);
+ OP_REQUIRES_OK(ctx, attrs.status());
+ attrs_ = attrs.ValueOrDie();
}
- int num_dims() const { return num_spatial_dims_ + 2; }
-
void Compile(XlaOpKernelContext* ctx) override {
- const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_);
- const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_);
-
- OP_REQUIRES(
- ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
-
- OP_REQUIRES(ctx, dilations_.size() == num_dims(),
- errors::InvalidArgument("Dilations field must "
- "specify ",
- num_dims(), " dimensions"));
- OP_REQUIRES(
- ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
- errors::Unimplemented("Current implementation does not support "
- "dilations in the batch and depth dimensions."));
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
- errors::Unimplemented("Dilation values must be positive; ", i,
- "th spatial dimension had dilation ",
- dilations_[input_dim]));
- }
-
- const TensorShape activations_shape = ctx->InputShape(0);
- TensorShape filter_shape;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
- const TensorShape out_backprop_shape = ctx->InputShape(2);
-
- const TensorShape expanded_filter_shape =
- depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
- : filter_shape;
-
- // Reuse dimension computation logic from conv_grad_ops.cc.
- ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx,
- ConvBackpropComputeDimensionsV2(
- type_string(), num_spatial_dims_, activations_shape,
- expanded_filter_shape, out_backprop_shape, dilations_,
- strides_, padding_, data_format_, &dims));
-
- xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp activations = ctx->Input(0);
- xla::XlaOp gradients = ctx->Input(2);
-
- // The filter gradients are computed by a convolution of the input
- // activations and the output gradients, with some appropriate padding.
- // See the comment at the top of conv_grad_ops.h for details.
-
- xla::ConvolutionDimensionNumbers dnums;
-
- // The activations (inputs) form the LHS of the convolution.
- // Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
- // For the gradient computation, we flip the roles of the batch and
- // feature dimensions.
- // Each spatial entry has size in_depth * batch
-
- // Swap n_dim and c_dim in the activations.
- dnums.set_input_batch_dimension(c_dim);
- dnums.set_input_feature_dimension(n_dim);
-
- // The gradients become the RHS of the convolution.
- // The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
- // where the batch becomes the input feature for the convolution.
- dnums.set_kernel_input_feature_dimension(n_dim);
- dnums.set_kernel_output_feature_dimension(c_dim);
-
- std::vector<std::pair<int64, int64>> padding(num_spatial_dims_);
- std::vector<int64> rhs_dilation(num_spatial_dims_);
- std::vector<int64> window_strides(num_spatial_dims_);
- std::vector<int64> ones(num_spatial_dims_, 1);
-
- // Tensorflow filter shape is [ H, W, ..., inC, outC ].
- for (int i = 0; i < num_spatial_dims_; ++i) {
- dnums.add_output_spatial_dimensions(i);
- }
- dnums.set_output_batch_dimension(num_spatial_dims_);
- dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
-
- for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
- dnums.add_input_spatial_dimensions(dim);
- dnums.add_kernel_spatial_dimensions(dim);
-
- // We will also need to pad the input with zeros such that after the
- // convolution, we get the right size for the filter.
- // The padded_in_rows should be such that when we convolve this with the
- // expanded_out_rows as a filter, we should get filter_rows back.
- //
- const int64 padded_in_size =
- dims.spatial_dims[i].expanded_output_size +
- (dims.spatial_dims[i].filter_size - 1) * dilations_[dim];
-
- // However it can be smaller than input_rows: in this
- // case it means some of the inputs are not used.
- //
- // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
- //
- // INPUT = [ A B C ]
- //
- // FILTER = [ x y ]
- //
- // and the output will only have one column: a = A * x + B * y
- //
- // and input "C" is not used at all.
- //
- // We apply negative padding in this case.
- const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
-
- // + For the VALID padding, we don't pad anything on the top/left side
- // and pad the bottom/right side with the remaining space.
- // + For the SAME padding, we pad top/left side the same as bottom/right
- // side.
- //
- // In addition, if the padded input size is smaller than the input size,
- // we need to ignore some training elements of the input. We do this by
- // applying negative padding on the right/bottom.
- const int64 pad_before =
- padding_ == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
- padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = dilations_[dim];
- }
-
- // Besides padding the input, we will also expand output_rows to
- // expanded_out_rows = (output_rows - 1) * stride + 1
- // with zeros in between:
- //
- // a . . . b . . . c . . . d . . . e
- //
- // This is done by specifying the window dilation factors in the
- // convolution HLO below.
- auto filter_backprop =
- xla::ConvGeneralDilated(activations, gradients, window_strides, padding,
- /*lhs_dilation=*/ones, rhs_dilation, dnums);
-
- if (depthwise_) {
- filter_backprop = ContractFilterForDepthwiseBackprop(
- ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
- }
- ctx->SetOutput(0, filter_backprop);
+ TensorShape filter_tensor_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape));
+ xla::Shape filter_shape =
+ TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape);
+
+ xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp(
+ ctx->op_kernel().type_string(), ctx->Input(0), filter_shape,
+ ctx->Input(2), attrs_);
+ OP_REQUIRES_OK(ctx, filter_backprop.status());
+ ctx->SetOutput(0, filter_backprop.ValueOrDie());
}
protected:
- const int num_spatial_dims_;
- const bool depthwise_;
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_ = FORMAT_NHWC;
+ ConvOpAttrs attrs_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
index a3389d5b90..4af1e8b44c 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc
@@ -34,15 +34,12 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* ctx) override {
- VLOG(3) << "DynamicUpdateSliceOp::Compile";
+ DataType index_type = ctx->InputType("indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
- DataType index_type = input_type(2);
- OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
- errors::InvalidArgument("index must be int32 or int64"));
-
- const TensorShape input_shape = ctx->InputShape(0);
- const TensorShape update_shape = ctx->InputShape(1);
- const TensorShape index_shape = ctx->InputShape(2);
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape update_shape = ctx->InputShape("update");
+ const TensorShape index_shape = ctx->InputShape("indices");
OP_REQUIRES(
ctx,
@@ -57,13 +54,56 @@ class DynamicUpdateSliceOp : public XlaOpKernel {
input_shape.DebugString(), "; update shape is ",
update_shape.DebugString()));
- xla::XlaOp result =
- xla::DynamicUpdateSlice(ctx->Input(0), ctx->Input(1), ctx->Input(2));
+ xla::XlaOp result = xla::DynamicUpdateSlice(
+ ctx->Input("input"), ctx->Input("update"), ctx->Input("indices"));
ctx->SetOutput(0, result);
}
};
REGISTER_XLA_OP(Name("XlaDynamicUpdateSlice"), DynamicUpdateSliceOp);
+class DynamicSliceOp : public XlaOpKernel {
+ public:
+ explicit DynamicSliceOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType index_type = ctx->InputType("start_indices");
+ CHECK(index_type == DT_INT32 || index_type == DT_INT64);
+ CHECK(index_type == ctx->InputType("size_indices"));
+
+ const TensorShape input_shape = ctx->InputShape("input");
+ const TensorShape start_indices_shape = ctx->InputShape("start_indices");
+ const TensorShape size_indices_shape = ctx->InputShape("size_indices");
+
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(start_indices_shape) &&
+ start_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "start_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and start_indices has shape ",
+ start_indices_shape.DebugString()));
+ OP_REQUIRES(ctx,
+ TensorShapeUtils::IsVector(size_indices_shape) &&
+ size_indices_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "size_indices must be a vector with length equal to "
+ "input rank, but input rank is ",
+ input_shape.dims(), " and size_indices has shape ",
+ size_indices_shape.DebugString()));
+
+ std::vector<int64> size_indices;
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputAsIntVector("size_indices", &size_indices));
+ xla::XlaOp result = xla::DynamicSlice(
+ ctx->Input("input"), ctx->Input("start_indices"), size_indices);
+ ctx->SetOutput(0, result);
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaDynamicSlice").CompileTimeConstInput("size_indices"),
+ DynamicSliceOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index d9a0257b70..7b2bb4a7c5 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
@@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size,
// If the 2D kernel would be very large, the 1D kernel can be applied once in
// each dimension due to the symmetry of the kernel along all axis to reduce the
// computational intensity.
-std::vector<float> Make1DKernel(int64 n) {
+xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
- return kernel;
+ return xla::ConstantR1<float>(builder, kernel);
}
// Kernels with more than 16 spatial elements are considered intense and the
@@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16;
xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
+ auto depthwise_kernel = xla::Broadcast(
+ xla::Zero(builder, xla::F32),
+ {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1});
- auto diag = xla::ConvertElementType(
- xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1,
- 2 * kernel_size[1] - 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
return xla::Mul(
- xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
+ xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]),
/*broadcast_dimensions=*/{1}),
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
+ Make1DKernel(builder, kernel_size[0]),
/*broadcast_dimensions=*/{0});
}
xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder,
absl::Span<const int64> kernel_size,
int64 channels, int64 dim) {
- xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels);
-
- auto diag = xla::ConvertElementType(
- xla::Eq(
- xla::Broadcast(channels_iota,
- {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
- dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
- xla::PrimitiveType::F32);
- if (dim == 1) {
- return xla::Mul(
- diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])),
- /*broadcast_dimensions=*/{1});
- }
- return xla::Mul(diag,
- xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])),
- /*broadcast_dimensions=*/{0});
+ auto depthwise_kernel =
+ xla::Broadcast(xla::Zero(builder, xla::F32),
+ {dim == 0 ? (2 * kernel_size[0] - 1) : 1,
+ dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1});
+ return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]),
+ /*broadcast_dimensions=*/{dim});
}
xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
@@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
dimension_numbers.add_input_spatial_dimensions(1 + i);
dimension_numbers.add_output_spatial_dimensions(1 + i);
@@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, upper_padding[0]},
{dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/dims.kernel_size,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
/*lhs_dilation=*/{dims.kernel_size[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
xla::XlaOp kernel1 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1);
output = xla::ConvGeneralDilated(
@@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder,
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
/*lhs_dilation=*/{1, dims.kernel_size[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// Add broadcasts to handle expanding from a size == 1 dimension to a
@@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
xla::ConvolutionDimensionNumbers dimension_numbers;
dimension_numbers.set_input_batch_dimension(0);
dimension_numbers.set_output_batch_dimension(0);
- dimension_numbers.set_input_feature_dimension(3);
- dimension_numbers.set_output_feature_dimension(3);
+ dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
- dimension_numbers.add_input_spatial_dimensions(1 + i);
- dimension_numbers.add_output_spatial_dimensions(1 + i);
+ dimension_numbers.add_input_spatial_dimensions(i + 1);
+ dimension_numbers.add_output_spatial_dimensions(i + 1);
dimension_numbers.add_kernel_spatial_dimensions(i);
}
- dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims);
- dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
+ dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
xla::XlaOp output;
if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
xla::XlaOp kernel =
@@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
{dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/dims.stride,
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
} else {
xla::XlaOp kernel0 =
MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0);
@@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder,
/*padding=*/
{{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
/*lhs_dilation=*/{dims.stride[0], 1},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
output = xla::ConvGeneralDilated(
output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
/*padding=*/
{{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
/*lhs_dilation=*/{1, dims.stride[1]},
- /*rhs_dilation=*/{1, 1}, dimension_numbers);
+ /*rhs_dilation=*/{1, 1}, dimension_numbers,
+ /*feature_group_count=*/channels);
}
// If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 2e0a69b70e..c8a0f31a03 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 64f2d781a6..5400e8834c 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -100,16 +100,6 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x,
precision_proto.add_operand_precision(precision);
precision_proto.add_operand_precision(precision);
- // If there are no batch dimensions, use a regular Dot.
- // TODO(b/69062148) Remove this code when Dot emitters can be passed
- // dimensions to transpose directly (i.e. without requiring a Transpose
- // HLO).
- if (batch_dimension_numbers.empty()) {
- auto lhs = transpose_x ? xla::Transpose(x, {1, 0}) : x;
- auto rhs = transpose_y ? xla::Transpose(y, {1, 0}) : y;
- return xla::Dot(lhs, rhs, &precision_proto);
- }
-
xla::DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(x_inner_dim);
dot_dnums.add_rhs_contracting_dimensions(y_inner_dim);
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
index ed452bceeb..15f4c38da2 100644
--- a/tensorflow/compiler/tf2xla/literal_util_test.cc
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -22,48 +22,61 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace {
TEST(LiteralUtil, LiteralToHostTensor) {
// int64 literal can only be converted to an int64 host tensor.
- {
- std::vector<int64> int64_values = {1, 2, 3};
- xla::Literal int64_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
- Tensor host_tensor;
- EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
- LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
- .error_message());
- EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
- LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
- .error_message());
- EXPECT_TRUE(
- LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
- test::ExpectTensorEqual<int64>(host_tensor,
- test::AsTensor<int64>(int64_values));
- }
+ std::vector<int64> int64_values = {1, 2, 3};
+ xla::Literal int64_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int64>(int64_values));
+ Tensor host_tensor;
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
+ LiteralToHostTensor(int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(int64_values_literal, DT_QINT32, &host_tensor)
+ .error_message());
+ EXPECT_TRUE(
+ LiteralToHostTensor(int64_values_literal, DT_INT64, &host_tensor).ok());
+ test::ExpectTensorEqual<int64>(host_tensor,
+ test::AsTensor<int64>(int64_values));
+}
+
+template <class T>
+using LiteralUtilTest = ::testing::Test;
+using Types =
+ ::testing::Types<std::pair<int8, qint8>, std::pair<uint8, quint8>,
+ std::pair<int16, qint16>, std::pair<uint16, quint16>,
+ std::pair<int32, qint32>>;
+
+TYPED_TEST_CASE(LiteralUtilTest, Types);
+
+TYPED_TEST(LiteralUtilTest, LiteralToQuantizedHostTensor) {
+ using int_type = typename TypeParam::first_type;
+ using qint_type = typename TypeParam::second_type;
- {
- // Repeat tests with int32.
- Tensor host_tensor;
- std::vector<int32> int32_values = {10, 11};
- xla::Literal int32_values_literal =
- xla::LiteralUtil::CreateR1(absl::Span<const int32>(int32_values));
- EXPECT_TRUE(
- LiteralToHostTensor(int32_values_literal, DT_INT32, &host_tensor).ok());
- test::ExpectTensorEqual<int32>(host_tensor,
- test::AsTensor<int32>(int32_values));
+ Tensor host_tensor;
+ std::vector<int_type> int_values = {10, 11};
+ xla::Literal int_values_literal =
+ xla::LiteralUtil::CreateR1(absl::Span<const int_type>(int_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<int_type>::value, &host_tensor)
+ .ok());
+ test::ExpectTensorEqual<int_type>(host_tensor,
+ test::AsTensor<int_type>(int_values));
- EXPECT_TRUE(
- LiteralToHostTensor(int32_values_literal, DT_QINT32, &host_tensor)
- .ok());
- std::vector<qint32> qint32_values = {10, 11};
- test::ExpectTensorEqual<qint32>(host_tensor,
- test::AsTensor<qint32>(qint32_values));
+ EXPECT_TRUE(LiteralToHostTensor(int_values_literal,
+ DataTypeToEnum<qint_type>::value,
+ &host_tensor)
+ .ok());
+ std::vector<qint_type> qint_values = {10, 11};
+ test::ExpectTensorEqual<qint_type>(host_tensor,
+ test::AsTensor<qint_type>(qint_values));
- EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
- LiteralToHostTensor(int32_values_literal, DT_INT64, &host_tensor)
- .error_message());
- }
+ EXPECT_EQ(
+ error::INVALID_ARGUMENT,
+ LiteralToHostTensor(int_values_literal, DT_INT64, &host_tensor).code());
}
+} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 68cfdc1785..733eeed3c6 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -105,6 +105,36 @@ dimension_numbers: a serialized xla::DotDimensionNumbers proto.
precision_config: a serialized xla::PrecisionConfig proto.
)doc");
+REGISTER_OP("XlaDynamicSlice")
+ .Input("input: T")
+ .Input("start_indices: Tindices")
+ .Input("size_indices: Tindices")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Wraps the XLA DynamicSlice operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
+.
+
+DynamicSlice extracts a sub-array from the input array at dynamic
+start_indices. The size of the slice in each dimension is passed in
+size_indices, which specify the end point of exclusive slice intervals in each
+dimension -- [start, start + size). The shape of start_indices must have rank 1,
+with dimension size equal to the rank of operand.
+
+input: A `Tensor` of type T.
+
+start_indices: Rank 1 tensor of N integers containing the starting indices of
+ the slice for each dimension. Value must be greater than or equal to zero.
+
+start_indices: List of N integers containing the slice size for each
+ dimension. Each value must be strictly greater than zero, and start + size
+ must be less than or equal to the size of the dimension to avoid
+ implementation defined behavior.
+)doc");
+
REGISTER_OP("XlaDynamicUpdateSlice")
.Input("input: T")
.Input("update: T")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 3626de375e..27dd18a9bb 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -291,13 +291,7 @@ def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
name=name)
-def dynamic_slice(x, starts, sizes, name=None):
- # TODO(phawkins): the Slice operator lowers to DynamicSlice if `starts` is not
- # a compile-time constant. This doesn't exactly mimic the semantics of dynamic
- # slice if the slice is out of bounds.
- return array_ops.slice(x, starts, sizes, name=name)
-
-
+dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
index 9d1992205b..b589512dcd 100644
--- a/tensorflow/compiler/tf2xla/shape_util.cc
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
// Convert a TensorShape into the equivalent XLA Shape proto.
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape) {
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+ *shape = TensorShapeToXLAShape(type, tensor_shape);
+ return Status::OK();
+}
+
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape) {
int rank = tensor_shape.dims();
std::vector<int64> dimensions(rank);
std::vector<int64> layout(rank);
@@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
// XLA uses minor-to-major; Tensorflow uses major-to-minor.
std::iota(layout.rbegin(), layout.rend(), 0);
- xla::PrimitiveType type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
-
- *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
- return Status::OK();
+ return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
index 58240b9c96..f7e34a5b40 100644
--- a/tensorflow/compiler/tf2xla/shape_util.h
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape,
Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
xla::Shape* shape);
+// Converts a TensorShape into the equivalent XLA Shape proto, taking an
+// xla::PrimitiveType to specify the element type. This never fails.
+xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type,
+ const TensorShape& tensor_shape);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 7dbe3a0b58..b22d53805d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -340,6 +341,13 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(),
second_copy_def, g.get()));
TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping));
+
+ // Functionalize control flow.
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g.get(), &flib_def));
+ // After control flow functionalization, we might have more FunctionDef's
+ // (then/else branch, loop body). Add them to the graph.
+ TF_RETURN_IF_ERROR(g->AddFunctionLibrary(flib_def.ToProto()));
+
*graph = std::move(g);
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index 211caf8736..d6f42bac86 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -25,9 +25,12 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -75,6 +78,8 @@ Status CheckFeedFetchNameConflicts(const string& kind,
} // namespace
+const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
+
Status ValidateConfig(const tf2xla::Config& config) {
std::set<string> names;
for (const tf2xla::Feed& feed : config.feed()) {
@@ -323,4 +328,101 @@ uint32 GetXLARandomSeed() {
return counter.fetch_add(2);
}
+// TODO(b/77601805): add tests for associated function related stuff.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr) {
+ if (flr->GetFunctionLibraryDefinition()->Contains(node_def.op())) {
+ return true;
+ }
+
+ if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. Gradient op has "f" attr, which is set to the function
+ // we are getting gradient for. That function is not associated with the op.
+ return false;
+ }
+
+ for (const auto& iter : node_def.attr()) {
+ if (iter.second.has_func()) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr) {
+ std::vector<AssociatedFunctionInfo> results;
+ const string& op = node.type_string();
+ if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
+ // This is a function call node.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
+ // Skip gradient op. Gradient op has "f" attr, which is set to the function
+ // we are getting gradient for. That function is not associated with the op.
+ } else {
+ // Collect all function attrs for the node.
+ for (auto& iter : node.attrs()) {
+ if (iter.second.has_func()) {
+ VLOG(2) << "Found function attr for node " << node.name() << ": "
+ << iter.first << " = " << iter.second.func().name();
+ results.emplace_back(AssociatedFunctionInfo(
+ iter.second.func().name(), iter.second.func().attr(), iter.first));
+ }
+ }
+ }
+ return results;
+}
+
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name) {
+ switch (associated_function.type()) {
+ case AssociatedFunctionInfo::kFunctionCallNode: {
+ // Change this node to call the new function.
+ NodeDefBuilder builder(node->name(), rewritten_function_name, fld);
+ for (auto attr : node->attrs()) {
+ builder.Attr(attr.first, attr.second);
+ }
+ for (int i = 0; i < node->num_inputs(); i++) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
+ builder.Input(input_node->name(), i, node->input_type(i));
+ }
+ builder.Device(node->assigned_device_name().empty()
+ ? node->requested_device()
+ : node->assigned_device_name());
+ NodeDef node_def;
+ TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
+ Status s;
+ Node* new_node = graph->AddNode(node_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ for (auto edge : node->in_edges()) {
+ graph->AddEdge(edge->src(), edge->src_output(), new_node,
+ edge->dst_input());
+ }
+ for (auto edge : node->out_edges()) {
+ graph->AddEdge(new_node, edge->src_output(), edge->dst(),
+ edge->dst_input());
+ }
+ graph->RemoveNode(node);
+ break;
+ }
+ case AssociatedFunctionInfo::kFunctionAttr: {
+ // Change function attr to rewritten functions.
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
+ node->ClearAttr(associated_function.attr_name());
+ func.set_name(rewritten_function_name);
+ node->AddAttr(associated_function.attr_name(), func);
+ break;
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index a29e764466..6065d0bb9a 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
@@ -59,6 +60,67 @@ void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype,
// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
+// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
+// a function call, or the node has function attrs).
+class AssociatedFunctionInfo {
+ public:
+ enum AssociatedFunctionType {
+ kFunctionCallNode = 0,
+ kFunctionAttr = 1,
+ };
+
+ // The node is a function call.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
+ : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
+
+ // The function is an attr of the node.
+ AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
+ const string& attr_name)
+ : type_(kFunctionAttr),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
+ AssociatedFunctionType type() const { return type_; }
+
+ const string& func_name() const { return func_name_; }
+
+ const string& attr_name() const { return attr_name_; }
+
+ const AttrValueMap& attrs() const { return attrs_; }
+
+ private:
+ // Available for all instances.
+ AssociatedFunctionType type_;
+ string func_name_;
+ AttrValueMap attrs_;
+
+ // Only available if the function is defined in an attr.
+ string attr_name_;
+};
+
+// Returns if the NodeDef has associated function.
+bool HasAssociatedFunction(const NodeDef& node_def,
+ FunctionLibraryRuntime* flr);
+
+// Gets functions associated with the node. Current cases:
+// 1. For function call node, its function name;
+// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
+ const Node& node, FunctionLibraryRuntime* flr);
+
+// Changes associated functions for the node. Current cases:
+// 1. For function call node, creates a new node with the new function name and
+// remove the old node;
+// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+Status RewriteAssociatedFunction(
+ Graph* graph, Node* node, FunctionLibraryDefinition* fld,
+ const AssociatedFunctionInfo& associated_function,
+ const string& rewritten_function_name);
+
+// Attribute to mark nodes to be executed on host.
+extern const char kXlaOutsideCompilationAttrName[];
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index c969212a1b..d00b137662 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -26,21 +26,26 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
*type = xla::PRED;
return Status::OK();
case tensorflow::DT_INT8:
+ case tensorflow::DT_QINT8:
*type = xla::S8;
return Status::OK();
case tensorflow::DT_INT16:
+ case tensorflow::DT_QINT16:
*type = xla::S16;
return Status::OK();
case tensorflow::DT_INT32:
+ case tensorflow::DT_QINT32:
*type = xla::S32;
return Status::OK();
case tensorflow::DT_INT64:
*type = xla::S64;
return Status::OK();
case tensorflow::DT_UINT8:
+ case tensorflow::DT_QUINT8:
*type = xla::U8;
return Status::OK();
case tensorflow::DT_UINT16:
+ case tensorflow::DT_QUINT16:
*type = xla::U16;
return Status::OK();
case tensorflow::DT_UINT32:
@@ -64,12 +69,6 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
case tensorflow::DT_COMPLEX64:
*type = xla::C64;
return Status::OK();
- case tensorflow::DT_QUINT8:
- *type = xla::U8;
- return Status::OK();
- case tensorflow::DT_QINT32:
- *type = xla::S32;
- return Status::OK();
default:
return errors::InvalidArgument(
"Unsupported type in DataTypeToPrimitiveType ",
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index dcb455779d..d5094e8ec5 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
-#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
@@ -150,6 +149,9 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function,
TF_RETURN_WITH_CONTEXT_IF_ERROR(
GetFunctionBody(function, flib_runtime_, fbody),
"Local lookup failed with: ", status.error_message());
+ VLOG(4) << "Function " << function.name() << " in flib_runtime_";
+ } else {
+ VLOG(4) << "Function " << function.name() << " in local_flib_runtime_";
}
return Status::OK();
}
@@ -323,8 +325,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
step_container->name(), XlaContext::kXlaContextResourceName,
xla_context));
- GraphCompiler graph_compiler(xla_context, device, graph.get(), flib,
- step_container.get());
+ GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
TF_RETURN_IF_ERROR(graph_compiler.Compile());
// Explicitly clean up the step container, to capture the cleanup status.
step_container.reset();
@@ -332,10 +333,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
}
// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
+// `args` is the list of input arguments, `retvals` is the list of retvals
+// produced by _Retval operators, in index order.
// If `return_updated_values_for_all_resources` is true, all resources will be
// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
@@ -743,18 +742,13 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
if (VLOG_IS_ON(2)) {
VLOG(2) << "XlaCompiler::CompileGraph: "
<< dump_graph::DumpGraphToFile(
- absl::StrCat("xla_compile_graph_", name), *graph);
+ absl::StrCat("xla_compile_graph_", name), *graph,
+ flib_runtime_->GetFunctionLibraryDefinition());
}
// Report the error here if initialization failed.
TF_RETURN_IF_ERROR(initialization_status_);
- // Converts Tensorflow's graph control-flow constructs into functional
- // control-flow that can be compiled into XLA code.
- TF_RETURN_IF_ERROR(
- FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(),
- graph.get(), local_flib_def_.get()));
-
// Detect invalid nodes.
// FunctionalizeControlFlow may remove some nodes from the graph.
TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def,
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 70efa7781d..72b17d04fc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
auto instr1 = c1.instructions(j);
auto instr2 = c2.instructions(j);
instr1.clear_name();
+ instr1.clear_id();
+ instr1.clear_operand_ids();
instr2.clear_name();
- // The names of instructions were uniquified by the XlaBuilder, the rest
- // of the fields should be identical.
+ instr2.clear_id();
+ instr2.clear_operand_ids();
+ // The names of instructions were uniquified by the XlaBuilder and the
+ // unique ids may be different, the rest of the fields should be
+ // identical.
string str1, str2;
+ LOG(INFO) << "instr1 = " << instr1.DebugString();
+ LOG(INFO) << "instr2 = " << instr2.DebugString();
instr1.AppendPartialToString(&str1);
instr2.AppendPartialToString(&str2);
EXPECT_EQ(str1, str2);
@@ -1219,25 +1226,8 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
CopyGraph(*graph, graph_copy.get());
XlaCompiler::CompilationResult result;
- status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
- std::move(graph_copy), args, &result);
- ASSERT_FALSE(status.ok());
- EXPECT_TRUE(
- absl::StrContains(status.error_message(),
- "The following nodes are unreachable "
- "from the source in the graph: {{node NoOp}}"))
- << status.error_message();
- }
-
- // Fix control edges for NoOp.
- {
- std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
- CopyGraph(*graph, graph_copy.get());
- EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get()));
- XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
std::move(graph_copy), args, &result));
- EXPECT_EQ(0, result.resource_updates.size());
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
index 23d04d43b3..bc44301d40 100644
--- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
+++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc
@@ -20,21 +20,6 @@ limitations under the License.
namespace tensorflow {
bool CpuOpFilter(KernelDef* kdef) {
- // TODO(b/34339814): implement inverse erf for double types and remove this
- // workaround.
- if (kdef->op() == "RandomStandardNormal") {
- kdef->clear_constraint();
- // Change the type constraint to permit only DTD_FLOAT.
- KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint();
- attr_constraint->set_name("dtype");
- attr_constraint->mutable_allowed_values()->mutable_list()->add_type(
- DT_FLOAT);
- return true;
- }
- // TODO(b/26783907): The CPU backend currently does not implement sort.
- if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") {
- return false;
- }
if (kdef->op() == "Const") {
AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index d10a504da0..2a9eaeee14 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -83,6 +83,10 @@ DataType XlaOpKernelContext::input_type(int index) const {
return context_->input(index).dtype();
}
+DataType XlaOpKernelContext::InputType(absl::string_view name) {
+ return GetInputTensorByName(name).dtype();
+}
+
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
xla::PrimitiveType type;
Status status = DataTypeToPrimitiveType(input_type(index), &type);
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 962c86d3a5..a3a0d10cc0 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -71,6 +71,9 @@ class XlaOpKernelContext {
// Returns the type of input `index`.
DataType input_type(int index) const;
+ // Returns the type of input `name`.
+ DataType InputType(absl::string_view name);
+
// Returns the type of input `index` as an xla::PrimitiveType. If the type
// is not representable as an XLA type, sets an error status and returns
// xla::PRIMITIVE_TYPE_INVALID.
@@ -79,7 +82,7 @@ class XlaOpKernelContext {
// Returns the shape of input `index`.
TensorShape InputShape(int index);
- // Returns the shape of input `name`.
+ // Returns the shape of input with name `name`.
TensorShape InputShape(absl::string_view name);
// Returns input `index` as a XlaOp. Unlike
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index b0eeee3174..91d48125f1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible compile time constant inputs.";
return false;
}
+ if (x.is_metadata_op != y.is_metadata_op) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible values for is_metadata_op.";
+ return false;
+ }
return true;
}
@@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
return &it->second.front()->compile_time_constant_inputs;
}
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end() || it->second.empty()) {
+ return false;
+ }
+
+ // The test in IsCompatible ensures that if there are multiple matching
+ // registrations for this op name, they all have the same value of
+ // is_metadata_op, so only the first match is returned.
+ return it->second.front()->is_metadata_op;
+}
+
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
+ registration_->is_metadata_op = true;
+ return *this;
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 74a4885f1f..4b2c2bacd6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU;
constexpr std::array<DataType, 4> kFloatTypes = {
{DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kNumericTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BFLOAT16}};
+constexpr std::array<DataType, 11> kNumericTypes = {
+ {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
+ DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}};
-constexpr std::array<DataType, 9> kCpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 14> kCpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
-constexpr std::array<DataType, 10> kGpuAllTypes = {
- {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 15> kGpuAllTypes = {
+ {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
+ DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
+ DT_BFLOAT16}};
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
@@ -136,6 +137,10 @@ class XlaOpRegistry {
static const std::unordered_set<string>* CompileTimeConstantInputs(
const string& op);
+ // Returns true if `op` is a "metadata" op, one that only looks at the shapes
+ // of its operands and not their values.
+ static bool IsMetadataOp(const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -192,6 +197,10 @@ class XlaOpRegistry {
// Names of arguments that must be compile-time constants.
std::unordered_set<string> compile_time_constant_inputs;
+ // True if this is a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ bool is_metadata_op = false;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -256,6 +265,10 @@ class XlaOpRegistrationBuilder {
// Mark 'input_name' as an argument whose value must be known at compile-time.
XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
+ // Mark this op as a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ XlaOpRegistrationBuilder& IsMetadataOp();
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 76e36f3c46..cc7390c6e6 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -193,6 +193,7 @@ cc_library(
":types",
":util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/synchronization",
],
)
@@ -244,6 +245,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 4e1ff9e5c0..95ff6432a5 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
- TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
+ LookUpInstructionByHandle(root_id));
ProgramShape program_shape;
- *program_shape.mutable_result() = instructions_[root_id].shape();
+ *program_shape.mutable_result() = root_proto->shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
return;
}
- CHECK(op_handle < instructions_.size() && op_handle >= 0);
-
- const HloInstructionProto& instr = instructions_[op_handle];
+ const HloInstructionProto& instr =
+ *(LookUpInstructionByHandle(op_handle).ValueOrDie());
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
switch (opcode) {
default:
@@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
// Clear data held by this builder.
this->instructions_.clear();
+ this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
@@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
- // also a valid denpendency order). The related ops will be added to the
+ // also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
@@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
worklist.push(root->id());
related_ops.insert(root->id());
while (!worklist.empty()) {
- int64 node = worklist.front();
+ int64 handle = worklist.front();
worklist.pop();
- for (int64 id : instructions_[node].operand_ids()) {
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
+ LookUpInstructionByHandle(handle));
+ for (int64 id : instr_proto->operand_ids()) {
if (related_ops.insert(id).second) {
worklist.push(id);
}
}
- for (int64 called_id : instructions_[node].called_computation_ids()) {
+ for (int64 called_id : instr_proto->called_computation_ids()) {
related_calls.insert(called_id);
}
}
@@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// Add related ops to the computation.
for (int64 id : related_ops) {
auto* instr = entry.add_instructions();
- *instr = instructions_[id];
+ TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
+ LookUpInstructionByHandle(id));
+ *instr = *instr_src;
// Ensures that the instruction names are unique among the graph.
const string& new_name =
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@@ -2415,11 +2420,11 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
- const int64 handle = instructions_.size();
+ const int64 handle = GetUniqueId();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode()));
+ instr.set_name(instr.opcode());
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
@@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
*instr.mutable_sharding() = *sharding_;
}
- instructions_.push_back(instr);
+ handle_to_index_[handle] = instructions_.size();
+ instructions_.push_back(std::move(instr));
XlaOp op(handle, this);
return op;
@@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
op.handle(), op.builder_->name(), this->name());
}
- if (op.handle() >= instructions_.size() || op.handle() < 0) {
- return InvalidArgument("no XlaOp value %d", op.handle());
+ return LookUpInstructionByHandle(op.handle());
+}
+
+StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
+ int64 handle) const {
+ auto it = handle_to_index_.find(handle);
+ if (it == handle_to_index_.end()) {
+ return InvalidArgument("No XlaOp with handle %d", handle);
}
- return &instructions_[op.handle()];
+ return &instructions_[it->second];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 833eafcf85..d0c59fa6f2 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
+ StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
+ int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
+ // A map from XlaOp::Handle to the index in the instructions_ vector where the
+ // instruction is held.
+ tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 0d3136b0cc..3ed3afcfce 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// regression.
flags->set_xla_cpu_enable_fast_math(true);
flags->set_xla_gpu_enable_fast_math(true);
+
+ flags->set_xla_force_host_platform_device_count(1);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -323,6 +325,17 @@ void AllocateFlags() {
flag_values->xla_gpu_crash_on_verification_failures(),
"Crashes the program on extra verification failures, e.g. cuDNN "
"cross checking failures"),
+ tensorflow::Flag(
+ "xla_force_host_platform_device_count",
+ int32_setter_for(
+ &DebugOptions::set_xla_force_host_platform_device_count),
+ flag_values->xla_force_host_platform_device_count(),
+ "Force the host platform to pretend that there are these many "
+ "host \"devices\". All of these host devices are backed by the same"
+ "threadpool. Setting this to anything other than 1 can increase "
+ "overhead from context switching but we let the user override this "
+ "behavior to help run tests on the host that run models in parallel "
+ "across multiple devices."),
});
ParseFlagsFromEnv(*flag_objects);
}
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index f1f255efae..5035f41988 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -1351,17 +1351,8 @@ StatusOr<Literal> LiteralBase::BitcastConvert(
return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true);
}
-StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape,
- bool round_f32_to_bf16) const {
+StatusOr<Literal> LiteralBase::ConvertToShape(const Shape& dest_shape) const {
if (!ShapeUtil::IsTuple(dest_shape)) {
- if (round_f32_to_bf16 && shape().element_type() == F32 &&
- dest_shape.element_type() == BF16) {
- auto converter = [](float src) {
- return tensorflow::bfloat16::round_to_bfloat16(src);
- };
- return ConvertBetweenNativeTypesWithConverter<float, bfloat16>(*this,
- converter);
- }
return Convert(dest_shape.element_type());
}
std::vector<Literal> elements;
@@ -1769,6 +1760,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
+ case S8:
+ proto->set_s8s(static_cast<const signed char*>(data<int8>().data()),
+ element_count());
+ break;
case U8:
proto->set_u8s(static_cast<const unsigned char*>(data<uint8>().data()),
element_count());
@@ -1859,6 +1854,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
break;
+ case S8: {
+ auto s8_data = data<int8>();
+ TF_RET_CHECK(proto.s8s().size() == s8_data.size());
+ std::copy(proto.s8s().begin(), proto.s8s().end(), s8_data.begin());
+ } break;
case U8: {
auto u8_data = data<uint8>();
TF_RET_CHECK(proto.u8s().size() == u8_data.size());
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index fa5b5f7fab..3cd3541fe1 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -203,6 +203,10 @@ class LiteralBase {
// Returns the count of the elements in the array at the given shape index in
// this literal.
int64 element_count(const ShapeIndex& index = {}) const {
+ if (index.empty()) {
+ // Common case, avoid GetSubshape().
+ return ShapeUtil::ElementsIn(shape());
+ }
return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index));
}
@@ -217,14 +221,7 @@ class LiteralBase {
// Converts this literal to the given shape. Returns an error is the
// conversion is not possible.
- //
- // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding
- // instead of truncation; otherwise, truncation is used.
- //
- // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
- // the default behavior.
- StatusOr<Literal> ConvertToShape(const Shape& dest_shape,
- bool round_f32_to_bf16 = false) const;
+ StatusOr<Literal> ConvertToShape(const Shape& dest_shape) const;
// Converts this literal to another primitive type using a bitcast
// conversion. The to and from primitive types must have the same bit
@@ -859,9 +856,9 @@ class BorrowingLiteral : public LiteralBase {
template <typename NativeT>
absl::Span<const NativeT> LiteralBase::Piece::data() const {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
@@ -872,9 +869,9 @@ absl::Span<const NativeT> LiteralBase::Piece::data() const {
template <typename NativeT>
absl::Span<NativeT> LiteralBase::Piece::data() {
- CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
- CHECK_EQ(subshape().element_type(),
- primitive_util::NativeToPrimitiveType<NativeT>())
+ DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
+ DCHECK_EQ(subshape().element_type(),
+ primitive_util::NativeToPrimitiveType<NativeT>())
<< "Attempting to access "
<< PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>())
<< " type, but literal element type is "
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index ba7fd29a62..7ad287c897 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -1640,6 +1640,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
auto one_f32 = LiteralUtil::CreateR0<float>(1.0);
auto two_f32 = LiteralUtil::CreateR0<float>(2.0);
auto vector_int8 = LiteralUtil::CreateR1<int8>({-128, 0, 2, 4, 7, 56, 127});
+ auto vector_uint8 = LiteralUtil::CreateR1<uint8>({128, 0, 2, 56, 127, 255});
auto vector_c64 = LiteralUtil::CreateR1<complex64>({{1.0, 2.0}, {3.0, 4.0}});
auto vector_bfloat16 = LiteralUtil::CreateR1<bfloat16>(
{bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}});
@@ -1658,6 +1659,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
};
EXPECT_EQ(one_f32, to_from_proto(one_f32));
+ EXPECT_EQ(vector_int8, to_from_proto(vector_int8));
+ EXPECT_EQ(vector_uint8, to_from_proto(vector_uint8));
EXPECT_EQ(vector_c64, to_from_proto(vector_c64));
EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16));
EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred));
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index 787725e884..b507a2ef79 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
@@ -49,16 +50,40 @@ string SanitizeFilename(const string& file_name) {
return safe_file_name;
}
+std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
+GetDirectoryExpanders() {
+ static auto* mutex = new tensorflow::mutex;
+ static auto* singleton = new std::vector<std::function<string(string)>>;
+ return {mutex, singleton};
+}
+
+// Runs all the directory expanders over x and returns the result.
+string Expand(string x) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ for (const auto& f : *pair.second) {
+ x = f(x);
+ }
+ return x;
+}
+
} // namespace
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();
- TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
+ string expanded_dir = Expand(directory);
+ TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(expanded_dir));
string safe_file_name = SanitizeFileName(file_name) + ".pb";
- const string path = tensorflow::io::JoinPath(directory, safe_file_name);
+ const string path = tensorflow::io::JoinPath(expanded_dir, safe_file_name);
return tensorflow::WriteBinaryProto(env, path, message);
}
+void RegisterDirectoryExpander(const std::function<string(string)>& expander) {
+ auto pair = GetDirectoryExpanders();
+ tensorflow::mutex_lock lock(*pair.first);
+ pair.second->push_back(expander);
+}
+
} // namespace protobuf_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
index 3667621367..f22fc8b849 100644
--- a/tensorflow/compiler/xla/protobuf_util.h
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -39,6 +39,10 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
+// Registers a function that may either expand a dirpath or forward the original
+// dirpath along as-is.
+void RegisterDirectoryExpander(const std::function<string(string)>& expander);
+
} // namespace protobuf_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 9da5dc0d2d..cd5fd33029 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated(
absl::Span<const int64> window_strides,
absl::Span<const std::pair<int64, int64>> padding,
absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers) {
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count) {
return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding,
- lhs_dilation, rhs_dilation, dimension_numbers);
+ lhs_dilation, rhs_dilation, dimension_numbers,
+ feature_group_count);
}
LocalOp LocalComputationBuilder::ConvertElementType(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 1d5dfe5911..2166bb6721 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -248,7 +248,8 @@ class LocalComputationBuilder {
absl::Span<const std::pair<int64, int64> > padding,
absl::Span<const int64> lhs_dilation,
absl::Span<const int64> rhs_dilation,
- const ConvolutionDimensionNumbers& dimension_numbers);
+ const ConvolutionDimensionNumbers& dimension_numbers,
+ int64 feature_group_count);
LocalOp ConvertElementType(const LocalOp& operand,
PrimitiveType new_element_type);
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index fa4366ff07..bb303c5678 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -1109,7 +1109,7 @@ class ComputationBuilder(object):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return self._client.DotGeneral(lhs, rhs, dimension_numbers)
- def Conv(self, lhs, rhs, window_strides, padding):
+ def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1):
"""Enqueues a Conv operation onto the computation.
Args:
@@ -1117,6 +1117,7 @@ class ComputationBuilder(object):
rhs: LocalOp for the rank N+2 array of kernel weights.
window_strides: length-N array-like of integer kernel strides.
padding: PaddingType representing either 'SAME' or 'VALID' padding.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the Conv operation.
"""
@@ -1125,10 +1126,11 @@ class ComputationBuilder(object):
self.GetShape(rhs).dimensions()[2:], window_strides)
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (),
- (), dimension_numbers)
+ (), dimension_numbers,
+ feature_group_count)
def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding,
- lhs_dilation, rhs_dilation):
+ lhs_dilation, rhs_dilation, feature_group_count=1):
"""Enqueues a ConvWithGeneralPadding operation onto the computation.
Args:
@@ -1138,6 +1140,7 @@ class ComputationBuilder(object):
padding: length-N array-like of pairs of integers of (low, high) padding.
lhs_dilation: length-N array-like of dilation factors.
rhs_dilation: length-N array-like of dilation factors.
+ feature_group_count: number of feature groups for grouped convolution.
Returns:
A ComputationdataHandle representing the added ConvWithGeneralPadding op.
@@ -1145,7 +1148,8 @@ class ComputationBuilder(object):
dimension_numbers = self._GetConvDimensionNumbers(len(window_strides))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def _GetConvDimensionNumbers(self, num_spatial_dims):
"""Create ConvolutionDimensionNumbers proto for convolutions."""
@@ -1163,7 +1167,8 @@ class ComputationBuilder(object):
return dimension_numbers
def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation,
- rhs_dilation, dimension_numbers):
+ rhs_dilation, dimension_numbers,
+ feature_group_count=1):
"""Enqueues a ConvGeneralDilated operation onto the computation.
Args:
@@ -1190,6 +1195,7 @@ class ComputationBuilder(object):
labels appear in the rhs_spec string, so that window_strides[0] is
matched with the dimension corresponding to the first character
appearing in rhs_spec that is not 'I' or 'O'.
+ feature_group_count: number of feature groups for grouped convolution.
Returns: a LocalOp representing the ConvGenralDilated operation.
"""
@@ -1215,7 +1221,8 @@ class ComputationBuilder(object):
key=lambda i: rhs_spec.index(out_spec[i])))
return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding,
lhs_dilation, rhs_dilation,
- dimension_numbers)
+ dimension_numbers,
+ feature_group_count)
def Sort(self, operand, dimension=-1):
"""Enqueues a sort operation onto the computation."""
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index fd98e19457..82103f0313 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest):
[40., 50., 0.]]]])
self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2)))
+ def testConvGeneralDilatedGroupedConvolutionF32(self):
+ c = self._NewComputation()
+ a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
+ lhs = a(1, 2, 2, 3)
+ rhs = a(2, 1, 1, 2) * 10
+ strides = [1, 1]
+ pads = [(1, 0), (0, 1)]
+ lhs_dilation = (2, 1)
+ rhs_dilation = (1, 1)
+ dimension_numbers = ("NCHW", "OIHW", "NCHW")
+ feature_group_count = 2
+ c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs),
+ strides, pads, lhs_dilation, rhs_dilation,
+ dimension_numbers, feature_group_count)
+ result = np.array([[[[0., 0., 0.],
+ [10., 20., 0.],
+ [0., 0., 0.],
+ [40., 50., 0.]],
+ [[0., 0., 0.],
+ [330., 380., 160.],
+ [0., 0., 0.],
+ [480., 530., 220.]]]])
+ self._ExecuteAndCompareClose(c, expected=result)
+
def testBooleanNot(self):
c = self._NewComputation()
arr = NumpyArrayBool([True, False, True])
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 05325367f5..ceb5e74db7 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -186,11 +186,10 @@ ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
/* static */ std::unique_ptr<std::vector<float>>
ReferenceUtil::ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
@@ -218,10 +217,9 @@ ReferenceUtil::ReduceWindow1DGeneric(
}
/* static */ std::unique_ptr<std::vector<float>>
-ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
- float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
@@ -234,9 +232,8 @@ ReferenceUtil::ReduceWindow1DAdd(const absl::Span<const float>& operand,
ReferenceUtil::ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.height(), operand.width()};
std::vector<int64> window_counts(window.size(), 0);
@@ -273,9 +270,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
std::vector<int64> dim_lengths{operand.height(), operand.width()};
return ReduceWindow2DGeneric(
@@ -284,9 +280,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
}
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
@@ -332,8 +327,8 @@ ReferenceUtil::ReduceWindow2DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
return ReduceWindow4DGeneric(
@@ -345,9 +340,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding) {
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding) {
std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
operand.n4()};
@@ -399,9 +393,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
}
/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding) {
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding) {
const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
padding);
@@ -425,8 +418,8 @@ ReferenceUtil::ReduceWindow4DGeneric(
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
const Array4D<float>& source,
float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
+ absl::Span<const int64> window,
+ absl::Span<const int64> stride,
bool same_padding) {
Padding padding = same_padding ? Padding::kSame : Padding::kValid;
auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9ce098029d..8654fbb9b5 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -177,47 +177,41 @@ class ReferenceUtil {
// Windowed reductions with Add as the function to apply.
static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
- const absl::Span<const float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const float> operand, float init,
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
- const Array2D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array2D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
- const Array3D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array3D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
- const Array4D<float>& operand, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ const Array4D<float>& operand, float init, absl::Span<const int64> window,
+ absl::Span<const int64> stride, Padding padding);
// Windowed reductions with a generic reduce function.
static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
- const absl::Span<const float>& operand, float init,
+ absl::Span<const float> operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
const Array2D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, Padding padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ Padding padding);
// With arbitrary padding.
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride,
- const absl::Span<const std::pair<int64, int64>>& padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ absl::Span<const std::pair<int64, int64>> padding);
// Batch normalize data.
static std::unique_ptr<Array4D<float>> BatchNorm4D(
@@ -230,8 +224,8 @@ class ReferenceUtil {
// TODO(b/74533103) Switch tests to evaluator and remove this implementation.
static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
const Array4D<float>& operand, const Array4D<float>& source, float init,
- const absl::Span<const int64>& window,
- const absl::Span<const int64>& stride, bool same_padding);
+ absl::Span<const int64> window, absl::Span<const int64> stride,
+ bool same_padding);
// Concatenates the lhs and rhs arrays along the concatenate_dimension.
// E.g. if concatenate_dimension is 0, the "n1"/height dimension is
@@ -332,8 +326,8 @@ class ReferenceUtil {
// Slices with index clamping
template <typename T>
- static std::vector<T> ClampSlice1D(const absl::Span<const T>& input,
- int64 start, int64 size) {
+ static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
+ int64 size) {
start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
std::vector<T> result;
for (int64 i = 0; i < size; ++i) {
diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD
index 97fcd37f6b..3abb3855a4 100644
--- a/tensorflow/compiler/xla/rpc/BUILD
+++ b/tensorflow/compiler/xla/rpc/BUILD
@@ -34,19 +34,28 @@ cc_library(
],
)
-tf_cc_binary(
- name = "grpc_service_main_cpu",
+cc_library(
+ name = "grpc_service_main_library",
srcs = ["grpc_service_main.cc"],
deps = [
":grpc_service",
"//tensorflow:grpc++",
"//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings:str_format",
],
)
+tf_cc_binary(
+ name = "grpc_service_main_cpu",
+ deps = [
+ ":grpc_service_main_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
tf_cc_test(
name = "grpc_client_test",
srcs = ["grpc_client_test.cc"],
diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
index d6b5149a24..522ab99fb1 100644
--- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc
+++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "grpcpp/server_builder.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/rpc/grpc_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -29,8 +30,15 @@ namespace {
int RealMain(int argc, char** argv) {
int32 port = 1685;
+ bool any_address = false;
+ string platform_str;
std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("port", &port, "port to listen on"),
+ tensorflow::Flag("platform", &platform_str,
+ "The XLA platform this service should be bound to"),
+ tensorflow::Flag("port", &port, "The TCP port to listen on"),
+ tensorflow::Flag(
+ "any", &any_address,
+ "Whether to listen to any host address or simply localhost"),
};
string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) {
}
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ se::Platform* platform = nullptr;
+ if (!platform_str.empty()) {
+ platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie();
+ }
std::unique_ptr<xla::GRPCService> service =
- xla::GRPCService::NewService().ConsumeValueOrDie();
+ xla::GRPCService::NewService(platform).ConsumeValueOrDie();
::grpc::ServerBuilder builder;
- string server_address(absl::StrFormat("localhost:%d", port));
+ string server_address(
+ absl::StrFormat("%s:%d", any_address ? "[::]" : "localhost", port));
+ builder.SetMaxReceiveMessageSize(INT_MAX);
builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials());
builder.RegisterService(service.get());
std::unique_ptr<::grpc::Server> server(builder.BuildAndStart());
LOG(INFO) << "Server listening on " << server_address;
server->Wait();
-
return 0;
}
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f4e24bff34..e800cf470c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -365,8 +365,11 @@ cc_library(
hdrs = ["pattern_matcher.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
+ "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/utility",
],
)
@@ -551,6 +554,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
@@ -589,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
@@ -1146,6 +1151,38 @@ tf_cc_test(
)
cc_library(
+ name = "hlo_module_group",
+ srcs = ["hlo_module_group.cc"],
+ hdrs = ["hlo_module_group.h"],
+ deps = [
+ ":hlo",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:statusor",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_module_group_test",
+ srcs = ["hlo_module_group_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":hlo_module_group",
+ ":hlo_module_group_metadata",
+ ":hlo_parser",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
name = "hlo_module_group_metadata",
srcs = ["hlo_module_group_metadata.cc"],
hdrs = ["hlo_module_group_metadata.h"],
@@ -1267,6 +1304,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/memory",
],
)
@@ -1400,6 +1438,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
@@ -1786,6 +1825,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"@com_google_absl//absl/memory",
@@ -1962,6 +2002,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
+ ":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
@@ -2521,6 +2562,7 @@ cc_library(
],
deps = [
":hlo",
+ ":hlo_module_group",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -2552,6 +2594,26 @@ cc_library(
],
)
+tf_cc_test(
+ name = "hlo_pass_pipeline_test",
+ srcs = ["hlo_pass_pipeline_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_parser",
+ ":hlo_pass_pipeline",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
cc_library(
name = "hlo_cse",
srcs = ["hlo_cse.cc"],
@@ -2623,6 +2685,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index c88a3a3b4b..75dae7a714 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -745,12 +745,25 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
}
const int64 rhs_kept_dim = 1 - rhs_collapsing_dim;
- auto reshape_if_necessary = [&](HloInstruction* hlo) {
- if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+ auto as_type = [&](HloInstruction* hlo, const PrimitiveType element_type) {
+ if (hlo->shape().element_type() == element_type) {
return hlo;
}
- return computation_->AddInstruction(
- HloInstruction::CreateReshape(dot->shape(), hlo));
+ return computation_->AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(hlo->shape(), element_type), hlo));
+ };
+
+ auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ hlo = as_type(hlo, dot->shape().element_type());
+ if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
+ hlo = computation_->AddInstruction(
+ HloInstruction::CreateReshape(dot->shape(), hlo));
+ }
+ return hlo;
+ };
+
+ auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
+ return AddReduce(as_type(hlo, F32), dim);
};
auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape,
@@ -770,7 +783,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
if (ShapeUtil::Rank(rhs->shape()) == 1 &&
ShapeUtil::Rank(lhs->shape()) == 1) {
TF_RETURN_IF_ERROR(
- ReplaceInstruction(dot, reshape_if_necessary(AddReduce(
+ ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
multiply(Flatten(lhs), Flatten(rhs)), 0))));
return true;
}
@@ -804,17 +817,17 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
(ShapeUtil::Rank(lhs->shape()) == 2 &&
lhs->shape().dimensions(lhs_kept_dim) == 1)) {
if (ShapeUtil::Rank(rhs->shape()) == 1) {
- TF_RETURN_IF_ERROR(ReplaceInstruction(
- dot,
- reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0))));
+ TF_RETURN_IF_ERROR(
+ ReplaceInstruction(dot, reshape_if_necessary(add_reduce_in_f32(
+ multiply(Flatten(lhs), rhs), 0))));
return true;
}
TF_RETURN_IF_ERROR(ReplaceInstruction(
- dot, reshape_if_necessary(
- AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
- rhs_collapsing_dim),
- rhs),
- rhs_collapsing_dim))));
+ dot, reshape_if_necessary(add_reduce_in_f32(
+ multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(),
+ rhs_collapsing_dim),
+ rhs),
+ rhs_collapsing_dim))));
return true;
}
@@ -826,7 +839,7 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
(ShapeUtil::Rank(rhs->shape()) == 2 &&
rhs->shape().dimensions(rhs_kept_dim) == 1)) {
TF_RETURN_IF_ERROR(ReplaceInstruction(
- dot, reshape_if_necessary(AddReduce(
+ dot, reshape_if_necessary(add_reduce_in_f32(
multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(),
lhs_collapsing_dim)),
lhs_collapsing_dim))));
@@ -1061,7 +1074,8 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
const int n =
right_operand->shape().dimensions(1 - rhs_contracting_dimension);
- auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n});
+ auto memoized_shape =
+ ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
auto* memoized_inst = computation_->AddInstruction(
HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
dnums, dot->precision_config()));
@@ -1109,10 +1123,12 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
HloInstruction *lhs, *rhs;
CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
- // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
- // below.
- if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
- ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
+ // Only optimize F32 or BF16 dot operations where the dot, rhs and lhs are
+ // rank 2 or below.
+ if ((dot->shape().element_type() != F32 &&
+ dot->shape().element_type() != BF16) ||
+ ShapeUtil::Rank(lhs->shape()) > 2 || ShapeUtil::Rank(rhs->shape()) > 2 ||
+ ShapeUtil::Rank(dot->shape()) > 2) {
return Status::OK();
}
@@ -2066,8 +2082,8 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
if (pad_literal == reduce_init_literal) {
return true;
}
- auto converted_pad_literal = pad_literal.ConvertToShape(
- reduce_init_value->shape(), /*round_f32_to_bf16=*/true);
+ auto converted_pad_literal =
+ pad_literal.ConvertToShape(reduce_init_value->shape());
if (!converted_pad_literal.ok()) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
index b864c372fa..9f8d0ee88b 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.h
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
// A pass which performs algebraic simplifications.
-class AlgebraicSimplifier : public HloPassInterface {
+class AlgebraicSimplifier : public HloModulePass {
public:
// Given shapes 'from_shape' and 'to_shape', determines if it is valid to
// bitcast from 'from_shape' to 'to_shape' after considering platform
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2427..2047f894b4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P(
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
- ::testing::tuple<int, int, int, bool, bool>> {};
+ ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
int m, k, n;
bool transpose_lhs, transpose_rhs;
- std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+ PrimitiveType element_type;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
- Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
- Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
- Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
- Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P(
DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
::testing::Values(1, 2), ::testing::Bool(),
- ::testing::Bool()));
+ ::testing::Bool(), ::testing::Values(F32, BF16)));
struct DotOfConcatTestSpec {
int64 m;
diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h
index 79d37f08d3..5b625bf3b9 100644
--- a/tensorflow/compiler/xla/service/batch_dot_simplification.h
+++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h
@@ -25,7 +25,7 @@ namespace xla {
// Normally these would live in the algebraic simplifier, but we want to run
// this to fixpoint (this pass reaches fixed point in one execution) before we
// run the DotDecomposer.
-class BatchDotSimplification : public HloPassInterface {
+class BatchDotSimplification : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override;
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h
index 76e32174f3..147f3ae7b6 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.h
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.h
@@ -26,7 +26,7 @@ namespace xla {
// A pass which rewrites batch norm operations into more operations. Breaking a
// big operation into smaller operations helps leverage our generic fusion
// logic.
-class BatchNormExpander : public HloPassInterface {
+class BatchNormExpander : public HloModulePass {
public:
// When use_fusion is set, a multi-output fusion node is created.
BatchNormExpander(bool rewrite_training_op = false,
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
- for (auto* instruction : module->entry_computation()->instructions()) {
+ for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
index 5dcd31b83d..cb3d12f0bf 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h
@@ -31,7 +31,7 @@ namespace xla {
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changed made by this pass.
-class BFloat16ConversionFolding : public HloPassInterface {
+class BFloat16ConversionFolding : public HloModulePass {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h
index 30b6346312..f48e925823 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.h
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
-class BFloat16Normalization : public HloPassInterface {
+class BFloat16Normalization : public HloModulePass {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
@@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface {
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
-class BFloat16MixedPrecisionRemoval : public HloPassInterface {
+class BFloat16MixedPrecisionRemoval : public HloModulePass {
public:
BFloat16MixedPrecisionRemoval() {}
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 545a6ecfb1..58f78f8e24 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -675,10 +675,8 @@ Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) {
continue;
}
if (!ShapeUtil::Equal(hlo->literal().shape(), hlo->shape())) {
- TF_ASSIGN_OR_RETURN(
- auto converted_literal,
- hlo->literal().ConvertToShape(hlo->shape(),
- /*round_f32_to_bf16=*/true));
+ TF_ASSIGN_OR_RETURN(auto converted_literal,
+ hlo->literal().ConvertToShape(hlo->shape()));
auto new_constant = computation->AddInstruction(
HloInstruction::CreateConstant(std::move(converted_literal)));
TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 1ee64971ab..6a62439f88 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -58,7 +58,7 @@ namespace xla {
// BFloat16ConversionFolding. If other passes are needed after this pass, run
// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
// pass.
-class BFloat16Propagation : public HloPassInterface {
+class BFloat16Propagation : public HloModulePass {
public:
explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 65fa951afe..34a7be0e9c 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1064,6 +1064,19 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// that seems to give the best results is lazy-best-fit, with all runs of
// alloc / free calls sorted in decreasing size order.
const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering();
+
+ // Returns a heap algorithm that chooses the best result from several
+ // algorithms.
+ auto get_heap_algorithm = [&](int64 alignment) {
+ auto algorithms =
+ absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>();
+ algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>(
+ absl::make_unique<LazyBestFitHeap>(alignment)));
+ algorithms->push_back(
+ absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment));
+ return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms));
+ };
+
if (run_whole_module_heap_simulation) {
// Run the heap simulation over the whole module. This reduces memory usage,
// since buffers for kCall, kWhile, and kConditional sub-computations are
@@ -1093,8 +1106,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
+ HeapSimulator::Run(get_heap_algorithm(alignment),
assignment->module(), schedule,
assignment->points_to_analysis(),
assignment->buffer_size_, options));
@@ -1123,12 +1135,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
options.buffers_to_assign = &buffer_value_set;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
- HeapSimulator::Run(
- absl::make_unique<DecreasingSizeRunsHeap>(
- absl::make_unique<LazyBestFitHeap>(alignment)),
- *computation, HloInstructionSequence(*instruction_sequence),
- assignment->points_to_analysis(), assignment->buffer_size_,
- options));
+ HeapSimulator::Run(get_heap_algorithm(alignment), *computation,
+ HloInstructionSequence(*instruction_sequence),
+ assignment->points_to_analysis(),
+ assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index c5cd88b9ea..08c4aff4f7 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -25,7 +25,7 @@ namespace xla {
// For every kCall operation in the main computation, we inline the body of the
// called function, and proceed recursively.
-class CallInliner : public HloPassInterface {
+class CallInliner : public HloModulePass {
public:
using InlinedInstructionMap =
std::unordered_map<HloInstruction*, HloInstruction*>;
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f173..e6b5665435 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, pred, "param"));
+ call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index e5a6c28478..96bd2616f5 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -97,7 +97,7 @@ CompileOnlyService::CompileAheadOfTime(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
HloModule::CreateFromProto(instance.computation, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*hlo_module));
hlo_modules.push_back(std::move(hlo_module));
}
diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h
index 3de50cbd7f..2223ad6753 100644
--- a/tensorflow/compiler/xla/service/conditional_simplifier.h
+++ b/tensorflow/compiler/xla/service/conditional_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that removes kConditional with a constant predicate, replacing them
// with their true or false computation as appropriate.
-class ConditionalSimplifier : public HloPassInterface {
+class ConditionalSimplifier : public HloModulePass {
public:
absl::string_view name() const override { return "simplify-conditional"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
index 498894737f..ce0138e56f 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which rewrites convolutions with feature_group_count > 1 into
// convolutions with feature_group_count = 1.
-class ConvolutionFeatureGroupConverter : public HloPassInterface {
+class ConvolutionFeatureGroupConverter : public HloModulePass {
public:
ConvolutionFeatureGroupConverter() {}
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index d308f6bc84..c097089e30 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -43,7 +43,7 @@ namespace xla {
// (3) The buffer set of the root instruction of the entry computation must be
// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
// InstructionAliasSet::IsDistinct return true.
-class CopyInsertion : public HloPassInterface {
+class CopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 8cc522a59e..b7103118ac 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/compiler/xla/service/cpu:cpu_runtime",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/stream_executor",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -180,6 +181,7 @@ cc_library(
":runtime_conv2d_mkl",
":runtime_fft",
":runtime_fork_join",
+ ":runtime_key_value_sort",
":runtime_matmul",
":runtime_matmul_mkl",
":runtime_single_threaded_conv2d",
@@ -461,12 +463,15 @@ cc_library(
],
copts = runtime_copts(),
deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
+ "//tensorflow/stream_executor",
+ "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
)
@@ -624,6 +629,18 @@ cc_library(
)
cc_library(
+ name = "runtime_key_value_sort",
+ srcs = ["runtime_key_value_sort.cc"],
+ hdrs = ["runtime_key_value_sort.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
name = "runtime_fork_join",
srcs = ["runtime_fork_join.cc"],
hdrs = ["runtime_fork_join.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
index 59437e88af..becee3f81f 100644
--- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -31,7 +31,7 @@ namespace cpu {
// called canonical convolutions). This pass expands non-canonical convolutions
// into reshapes and canonical convolutions, so that these non-canonical
// convolutions can run faster.
-class ConvCanonicalization : public HloPassInterface {
+class ConvCanonicalization : public HloModulePass {
public:
explicit ConvCanonicalization(
const TargetMachineFeatures* target_machine_features)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
index d49f7d7cc2..076235f887 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h
@@ -30,7 +30,7 @@ namespace xla {
//
// TODO(b/62548313): Remove this when buffer assignment is smarter
// (module-scoped).
-class CpuCopyInsertion : public HloPassInterface {
+class CpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
index 6af724b2a5..a39a9d4724 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the CPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class CpuHloSupportChecker : public HloPassInterface {
+class CpuHloSupportChecker : public HloModulePass {
public:
CpuHloSupportChecker() = default;
~CpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 8a44c384bb..20cf855735 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,19 +17,29 @@ limitations under the License.
#include <functional>
+#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
namespace cpu {
namespace runtime {
-XfeedManager* GetXfeedManager() {
- static XfeedManager* manager = new XfeedManager;
- return manager;
+XfeedManager* GetXfeedManager(int device_ordinal) {
+ static tensorflow::gtl::FlatMap<int, XfeedManager*>* managers =
+ new tensorflow::gtl::FlatMap<int, XfeedManager*>();
+ static absl::Mutex* mutex = new absl::Mutex();
+
+ absl::MutexLock lock(mutex);
+ auto it = managers->find(device_ordinal);
+ if (it == managers->end()) {
+ it = managers->emplace(device_ordinal, new XfeedManager()).first;
+ }
+ return it->second;
}
extern const char* const kEigenMatMulF16SymbolName =
@@ -74,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName =
"__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation";
extern const char* const kParallelForkJoinSymbolName =
"__xla_cpu_runtime_ParallelForkJoin";
+extern const char* const kKeyValueSortPREDSymbolName =
+ "__xla_cpu_runtime_KeyValueSortPRED";
+extern const char* const kKeyValueSortS8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS8";
+extern const char* const kKeyValueSortU8SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU8";
+extern const char* const kKeyValueSortS16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS16";
+extern const char* const kKeyValueSortU16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU16";
+extern const char* const kKeyValueSortF16SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF16";
+extern const char* const kKeyValueSortS32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS32";
+extern const char* const kKeyValueSortU32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU32";
+extern const char* const kKeyValueSortF32SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF32";
+extern const char* const kKeyValueSortS64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortS64";
+extern const char* const kKeyValueSortU64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortU64";
+extern const char* const kKeyValueSortF64SymbolName =
+ "__xla_cpu_runtime_KeyValueSortF64";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
} // namespace runtime
@@ -94,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) {
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
- const void* shape,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireInfeedBufferForDequeue: "
- << ShapeString(shape, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireInfeedBufferForDequeue: "
+ << ShapeString(shape, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->infeed()->BlockingDequeueBuffer();
@@ -114,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseInfeedBufferAfterDeque: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseInfeedBufferAfterDeque: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
@@ -130,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void*
-__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "AcquireOutfeedBufferForPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "AcquireOutfeedBufferForPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
// Wait until there's a buffer to dequeue.
xla::cpu::runtime::XfeedBuffer* buffer =
xfeed->outfeed()->BlockingDequeueBuffer();
@@ -150,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length,
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
-__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length,
- void* buffer_ptr,
- const void* shape_ptr,
- xla::int32 shape_length) {
- if (VLOG_IS_ON(2)) {
- LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: "
- << ShapeString(shape_ptr, shape_length);
- }
- xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager();
+__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) {
+ int device_ordinal =
+ run_options ? run_options->stream()->parent()->device_ordinal() : 0;
+
+ VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: "
+ << ShapeString(shape_ptr, shape_length) << " on stream executor "
+ << device_ordinal;
+
+ xla::cpu::runtime::XfeedManager* xfeed =
+ xla::cpu::runtime::GetXfeedManager(device_ordinal);
xla::StatusOr<xla::Shape> shape =
xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length);
xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index aa0e967123..b2e760a224 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -26,6 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/types.h"
@@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName;
extern const char* const kAcquireOutfeedBufferForPopulationSymbolName;
extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
+extern const char* const kKeyValueSortPREDSymbolName;
+extern const char* const kKeyValueSortS8SymbolName;
+extern const char* const kKeyValueSortU8SymbolName;
+extern const char* const kKeyValueSortS16SymbolName;
+extern const char* const kKeyValueSortU16SymbolName;
+extern const char* const kKeyValueSortF16SymbolName;
+extern const char* const kKeyValueSortS32SymbolName;
+extern const char* const kKeyValueSortU32SymbolName;
+extern const char* const kKeyValueSortF32SymbolName;
+extern const char* const kKeyValueSortS64SymbolName;
+extern const char* const kKeyValueSortU64SymbolName;
+extern const char* const kKeyValueSortF64SymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
extern const char* const kXlaCpuRuntimeSymbolNamePrefix;
-// Returns the infeed manager used by the CPU runtime.
-XfeedManager* GetXfeedManager();
+// Returns the infeed manager used by the CPU runtime for the CPU device
+// `device_ordinal`. Note the device ordinal does not name a CPU
+XfeedManager* GetXfeedManager(int device_ordinal);
} // namespace runtime
} // namespace cpu
@@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager();
extern "C" {
+// Some things common to all of the runtime entry points below:
+//
+// * The shape pointer and shape_length reflect values that can be deserialized
+// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass
+// reified type information from the generated program to the runtime, which
+// helps check the type safety and contract for the emitted-code/runtime
+// communication.
+//
+// * run_options is used to look up the device ordinal for the stream executor
+// we're executing under. If it is null the device ordinal is assumed to be
+// 0 (this behavior helps in writing tests).
+
// Note: in the runtime entry points below, the shape pointer and shape_length
// reflect values that can be deserialized via
// llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified
@@ -89,7 +115,8 @@ extern "C" {
// the length would be more exact, but the length check is chosen as a
// tradeoff between error checking and speed/simplicity.
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- xla::int32 buffer_length, const void* shape, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape, xla::int32 shape_length);
// Relinquishes the next infeed buffer that was returned by
// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
@@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
// implemented we will add support for multiple outstanding buffers
// that can be returned out of order.
extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
// Blocks until the next outfeed buffer is available to be populated, then
// returns it.
extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ const void* shape_ptr, xla::int32 shape_length);
// Relinquishes the outfeed buffer after it has been populated.
// buffer_ptr must have been previously returned by
@@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
// acquired, i.e., there may only be one outstanding outfeed buffer in
// use by the runtime.
extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr,
- xla::int32 shape_length);
+ const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length,
+ void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length);
} // extern "C"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 5519a43b2f..1cc2844470 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/stream_executor/stream_executor.h"
namespace xla {
@@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed(
buffers.push_back(buffer);
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);
cleanup.release();
@@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor,
TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
TransferBufferToInfeedInternal(executor, size, source));
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});
return Status::OK();
@@ -265,7 +268,8 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal(
buffer_pointers.push_back(b.get());
}
- cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed_manager =
+ cpu::runtime::GetXfeedManager(executor->device_ordinal());
xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
VLOG(2) << "Waiting for buffer to be notified as populated.";
std::vector<Shape> outfed_shapes;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index df8c2a636b..c3e8020783 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
llvm::Value * shape_ptr,
llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_));
- // The signature of the acquire infeed buffer function is:
- //
- // (void*)(int32 length);
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::FunctionType* acquire_type = llvm::FunctionType::get(
- i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
+ i8_ptr_type,
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* acquire_func;
@@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
}
acquire_func->setCallingConv(llvm::CallingConv::C);
- // The signature of the release infeed buffer function is:
- //
- // (void)(int32 length, void* buffer);
llvm::FunctionType* release_type = llvm::FunctionType::get(
- b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
+ b_.getVoidTy(),
+ {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type,
+ /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type,
+ /*shape_length*/ int32_type},
/*isVarArg=*/false);
llvm::Function* release_func;
@@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
// of size exactly 'length_32', and the runtime is responsible for
// check-failing the process if there is a mismatch, versus passing us back a
// buffer that we might overrun.
- llvm::Value* acquired_pointer =
- Call(acquire_func,
- {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)});
+ llvm::Value* acquired_pointer = Call(
+ acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ shape_ptr, b_.getInt32(shape_length)});
if (kind == XfeedKind::kInfeed) {
// Copy to the program buffer address from the acquired buffer.
@@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
/*SrcAlign=*/1, length_32);
}
- Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr,
- b_.getInt32(shape_length)});
+ Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32),
+ acquired_pointer, shape_ptr, b_.getInt32(shape_length)});
return Status::OK();
}
@@ -495,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
}
Status IrEmitter::HandleSort(HloInstruction* sort) {
- // TODO(b/26783907): Implement sort on CPU.
- return Unimplemented("Sort is not implemented on CPU.");
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort));
+ auto keys = sort->operand(0);
+ auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+ ShapeIndex keys_shape_index({});
+ ShapeIndex values_shape_index({});
+ if (values != nullptr) {
+ keys_shape_index = ShapeIndex({0});
+ values_shape_index = ShapeIndex({1});
+ }
+ auto keys_destination = GetAllocationSlice(*sort, keys_shape_index);
+ auto keys_destination_address =
+ EmitBufferPointer(keys_destination, keys->shape());
+ auto values_destination = GetAllocationSlice(*sort, values_shape_index);
+ llvm::Value* values_destination_address = nullptr;
+
+ // The sort is implemented in-place, therefore we first copy the operand
+ // buffer to the output buffer if they are not the same.
+ if (keys_destination != GetAllocationSlice(*keys)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(keys);
+ int64 keys_size = ByteSizeOf(keys->shape());
+ MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, keys_size);
+ }
+ if (values != nullptr) {
+ values_destination_address =
+ EmitBufferPointer(values_destination, values->shape());
+ if (values_destination != GetAllocationSlice(*values)) {
+ int64 primitive_type_size =
+ ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type());
+ auto source_buffer = GetEmittedValueFor(values);
+ int64 values_size = ByteSizeOf(values->shape());
+ MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size,
+ source_buffer,
+ /*SrcAlign=*/primitive_type_size, values_size);
+ }
+ }
+
+ // Normalize the shape and the dimension to sort.
+ Shape normalized_keys_shape =
+ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
+ keys->shape());
+ int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical(
+ keys->shape().layout())[sort->dimensions(0)];
+
+ int64 sort_dimension_elements =
+ normalized_keys_shape.dimensions(physical_dimension_to_sort);
+ int64 higher_dimensions = 1;
+ for (int64 i = 0; i < physical_dimension_to_sort; ++i) {
+ higher_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+ int64 lower_dimensions = 1;
+ for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1;
+ i > physical_dimension_to_sort; --i) {
+ lower_dimensions *= normalized_keys_shape.dimensions(i);
+ }
+
+ PrimitiveType keys_type = keys->shape().element_type();
+ const char* fn_name = nullptr;
+ llvm::Type* keys_native_type = nullptr;
+ switch (keys_type) {
+ case PRED:
+ fn_name = runtime::kKeyValueSortPREDSymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S8:
+ fn_name = runtime::kKeyValueSortS8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case U8:
+ fn_name = runtime::kKeyValueSortU8SymbolName;
+ keys_native_type = b_.getInt8PtrTy();
+ break;
+ case S16:
+ fn_name = runtime::kKeyValueSortS16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case U16:
+ fn_name = runtime::kKeyValueSortU16SymbolName;
+ keys_native_type = b_.getInt16Ty()->getPointerTo();
+ break;
+ case F16:
+ fn_name = runtime::kKeyValueSortF16SymbolName;
+ keys_native_type = b_.getHalfTy()->getPointerTo();
+ break;
+ case S32:
+ fn_name = runtime::kKeyValueSortS32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case U32:
+ fn_name = runtime::kKeyValueSortU32SymbolName;
+ keys_native_type = b_.getInt32Ty()->getPointerTo();
+ break;
+ case F32:
+ fn_name = runtime::kKeyValueSortF32SymbolName;
+ keys_native_type = b_.getFloatTy()->getPointerTo();
+ break;
+ case S64:
+ fn_name = runtime::kKeyValueSortS64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case U64:
+ fn_name = runtime::kKeyValueSortU64SymbolName;
+ keys_native_type = b_.getInt64Ty()->getPointerTo();
+ break;
+ case F64:
+ fn_name = runtime::kKeyValueSortF64SymbolName;
+ keys_native_type = b_.getDoubleTy()->getPointerTo();
+ break;
+ default:
+ return Unimplemented(
+ "Element type %s not supported in the Sort op on CPU.",
+ PrimitiveType_Name(keys_type));
+ }
+
+ llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get(
+ b_.getVoidTy(),
+ {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(),
+ b_.getInt8PtrTy(), b_.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto* key_value_sort_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, key_value_sort_type));
+ key_value_sort_func->setCallingConv(llvm::CallingConv::C);
+ key_value_sort_func->setDoesNotThrow();
+ key_value_sort_func->setOnlyAccessesArgMemory();
+ Call(key_value_sort_func,
+ {PointerCast(keys_destination_address, keys_native_type),
+ b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements),
+ b_.getInt64(lower_dimensions),
+ values != nullptr
+ ? PointerCast(values_destination_address, b_.getInt8PtrTy())
+ : llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType(
+ values->shape().element_type())
+ : 0)});
+
+ if (values != nullptr) {
+ llvm_ir::EmitTuple(GetIrArrayFor(sort),
+ {keys_destination_address, values_destination_address},
+ &b_, module_);
+ }
+ return Status::OK();
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 3df99464ba..daafef4eb3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status Preprocess(HloInstruction* hlo) override;
Status Postprocess(HloInstruction* hlo) override;
+ // A convenient helper for calling BufferAssignment::GetUniqueSlice.
+ BufferAllocation::Slice GetAllocationSlice(
+ const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+ return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie();
+ }
+
private:
// Private helper to initialize an IR function for the computation.
void InitializeIrFunction(const string& function_name);
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index b4c0c09ec0..ede7f433ca 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
+ opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index a99cd99c14..3822d5300e 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -60,7 +60,7 @@ class ParallelTaskAssignment {
// own embedded computation, which is compiled as a parallel compute function,
// and which is invoked from a kCall instruction that is lowered in codegen to
// a runtime parallel fork/join call.
-class ParallelTaskAssigner : public HloPassInterface {
+class ParallelTaskAssigner : public HloModulePass {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
new file mode 100644
index 0000000000..e0e7deb98e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace {
+using tensorflow::int16;
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::int8;
+using tensorflow::uint16;
+using tensorflow::uint32;
+using tensorflow::uint64;
+using tensorflow::uint8;
+
+template <typename KeyType>
+void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements);
+}
+
+// For floating point numbers, we want a total order comparator. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0. Also we want to have a stable sort, so if the keys are the
+// same, we compare the index values.
+template <typename KeyType>
+bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) {
+ bool lhs_is_negative = std::signbit(lhs);
+ bool rhs_is_negative = std::signbit(rhs);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(lhs);
+ bool rhs_nan = std::isnan(rhs);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
+ }
+ if (lhs != rhs) {
+ return lhs < rhs;
+ }
+ return lhs_index < rhs_index;
+}
+
+template <>
+void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<double, int64>& lhs,
+ const std::pair<double, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<float, int64>& lhs,
+ const std::pair<float, int64>& rhs) -> bool {
+ return LessThan(lhs.first, lhs.second, rhs.first, rhs.second);
+ });
+}
+
+template <>
+void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort,
+ int64 num_elements) {
+ std::sort(row_to_sort, row_to_sort + num_elements,
+ [](const std::pair<Eigen::half, int64>& lhs,
+ const std::pair<Eigen::half, int64>& rhs) -> bool {
+ return LessThan(
+ Eigen::half_impl::half_to_float(lhs.first), lhs.second,
+ Eigen::half_impl::half_to_float(rhs.first), rhs.second);
+ });
+}
+
+template <typename KeyType>
+void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ // High-level idea of the iteration/sorting logic:
+ // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the
+ // dimension to sort, c is the product of the more minor dimensions (set to 1
+ // if b is the most minor dimension), and a is the product of the more major
+ // dimensions (set to 1 if b is the most major dimension). There are a * c
+ // many rows that we need to sort. We iterate through these, calculate a
+ // 'base_offset' value which points to the first element in that row, and add
+ // i * c for accessing the 'i'-th element in that row.
+
+ int64 sort_dimension_elements = b;
+ int64 num_iteration_elements = a * c;
+ int64 sort_dimension_offset = c;
+
+ std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort(
+ new std::pair<KeyType, int64>[sort_dimension_elements]);
+ std::unique_ptr<std::string[]> reordered_values(
+ new std::string[sort_dimension_elements]);
+ for (int64 index = 0; index < num_iteration_elements; ++index) {
+ // 'index' can be split into two values which index into the 'c' dimension
+ // and the 'a' dimension, respectively. 'index' % 'c' is the index into the
+ // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When
+ // calculating the base offset, we need to multiply the index into the 'a'
+ // dimension with 'b' * 'c'.
+ // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'.
+ int64 base_offset =
+ index % sort_dimension_offset +
+ (index - index % sort_dimension_offset) * sort_dimension_elements;
+ // TODO(b/26783907): We could define a custom iterator class that references
+ // both arrays. Then we could avoid the intermediate copy. However this
+ // would become more complicated, and it is not clear if the benefit is high
+ // enough.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ row_to_sort[i] =
+ std::make_pair(keys[base_offset + i * sort_dimension_offset], i);
+ }
+ KeyValueSort(row_to_sort.get(), sort_dimension_elements);
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first;
+ }
+ if (values == nullptr) {
+ continue;
+ }
+
+ // Reorder the values according to the order defined by the keys.
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index =
+ (base_offset + row_to_sort[i].second * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+
+ reordered_values[i] = std::string(values + memory_index,
+ values_primitive_type_size_in_bytes);
+ }
+ for (int64 i = 0; i < sort_dimension_elements; ++i) {
+ int64 memory_index = (base_offset + i * sort_dimension_offset) *
+ values_primitive_type_size_in_bytes;
+ memcpy(values + memory_index, reordered_values[i].c_str(),
+ values_primitive_type_size_in_bytes);
+ }
+ }
+}
+} // namespace
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8(
+ int8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8(
+ uint8* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16(
+ int16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16(
+ uint16* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32(
+ int32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32(
+ uint32* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64(
+ int64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64(
+ uint64* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, int64 a, int64 b, int64 c, char* values,
+ int32 values_primitive_type_size_in_bytes) {
+ KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
new file mode 100644
index 0000000000..28e35e82c1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h
@@ -0,0 +1,88 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b'
+// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr.
+// If 'values' is not nullptr, the elements in 'values' are reordered in such a
+// way that if the element at index 'i' in 'keys' was moved to index 'j', the
+// element at index 'i' in 'values' is also moved to index 'j' (which means that
+// the same elements correspond to each other as before).
+extern void __xla_cpu_runtime_KeyValueSortPRED(
+ bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS8(
+ tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU8(
+ tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS16(
+ tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU16(
+ tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF16(
+ Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS32(
+ tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU32(
+ tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF32(
+ float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortS64(
+ tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortU64(
+ tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b,
+ tensorflow::int64 c, char* values,
+ tensorflow::int32 values_primitive_type_size_in_bytes);
+
+extern void __xla_cpu_runtime_KeyValueSortF64(
+ double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c,
+ char* values, tensorflow::int32 values_primitive_type_size_in_bytes);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index bf98064647..9ec0c8f657 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
@@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
+ REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index c55206eee7..4b129c95d4 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -180,3 +180,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cpu_key_value_sort_test",
+ srcs = ["cpu_key_value_sort_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
new file mode 100644
index 0000000000..3934c03a04
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+class CpuKeyValueSortTest : public CpuCodegenTest {};
+
+TEST_F(CpuKeyValueSortTest, SortR1) {
+ const string hlo_text = R"(
+HloModule KeyValueSort
+
+ENTRY main {
+ a = f32[10] parameter(0)
+
+ ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+}
+)";
+
+ string filecheck_pattern = R"(
+CHECK: call void @__xla_cpu_runtime_KeyValueSort
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ CpuAotCompilationOptions options{
+ /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"",
+ /*entry_point_name=*/"entry",
+ /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static};
+
+ CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern,
+ /*match_optimized_ir=*/true);
+}
+
+} // namespace
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
index 8fe65f488a..cc38b81455 100644
--- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc
@@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) {
auto shape = ShapeUtil::MakeShape(U8, {length});
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
- length, bytes.data(), bytes.size());
- __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer,
- bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
// Performs the acquire/release sequence on the outfeed, as the generated CPU
@@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) {
void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) {
string bytes = shape.SerializeAsString();
void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(
- length, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, bytes.data(), bytes.size());
__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
- length, buffer, bytes.data(), bytes.size());
+ /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size());
}
TEST_F(InfeedManagerTest, SingleThreadedSequential) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
xfeed->infeed()->EnqueueBuffersAtomically({b});
@@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TestInfeedBuffer* a = new TestInfeedBuffer(64);
TestInfeedBuffer* b = new TestInfeedBuffer(32);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->infeed()->EnqueueBuffersAtomically({a});
ProcessNextBuffer(a->length());
@@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
TEST_F(InfeedManagerTest, MultiThreaded) {
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
const int32 length = 64;
@@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) {
TEST_F(InfeedManagerTest, OutfeedWrongShape) {
TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false);
- cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager();
+ cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0);
xfeed->outfeed()->EnqueueBuffersAtomically({b});
ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33}));
diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h
index c326beb899..aaa41fc4fe 100644
--- a/tensorflow/compiler/xla/service/defuser.h
+++ b/tensorflow/compiler/xla/service/defuser.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which replaces all fusion instructions with the equivalent un-fused
// instructions.
-class Defuser : public HloPassInterface {
+class Defuser : public HloModulePass {
public:
Defuser() {}
~Defuser() override {}
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index ba2a674d9a..b3549acfc2 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -24,7 +24,7 @@ namespace xla {
namespace {
// Pass which strips control dependencies from all instructions in the module.
-class ControlDepRemover : public HloPassInterface {
+class ControlDepRemover : public HloModulePass {
public:
ControlDepRemover() = default;
absl::string_view name() const override { return "control-dep-remover"; }
diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h
index 7be70add2f..46dcc3a438 100644
--- a/tensorflow/compiler/xla/service/despecializer.h
+++ b/tensorflow/compiler/xla/service/despecializer.h
@@ -30,7 +30,7 @@ namespace xla {
//
// Current despecialization passes are Defuser, ImplicitBroadcastRemover,
// and BFloat16MixedPrecisionRemoval.
-class Despecializer : public HloPassInterface {
+class Despecializer : public HloModulePass {
public:
Despecializer();
absl::string_view name() const override { return "despecializer"; }
diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h
index fc38e31700..40e7a3b4c2 100644
--- a/tensorflow/compiler/xla/service/dot_decomposer.h
+++ b/tensorflow/compiler/xla/service/dot_decomposer.h
@@ -23,7 +23,7 @@ namespace xla {
// DotDecomposer is a pass which decomposes batch Dot operations into a
// sequence of smaller (R2) Dot operations.
-class DotDecomposer : public HloPassInterface {
+class DotDecomposer : public HloModulePass {
public:
// Decomposes batch Dot operations when 'decompose_batch_dot' is true.
DotDecomposer(bool decompose_batch_dot = true)
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4bb1e071d8..515267edd7 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
- if (prim_type != F32) {
- // TODO(b/34339814): Implement inverse erf for F64.
+ if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
return Unimplemented(
"Inverse erf is only implemented for element "
- "type F32.");
+ "types F16, F32 and F64.");
}
- auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(b_->getFloatTy(), f);
+
+ // Upcast half to float.
+ if (prim_type == F16) {
+ x = b_->CreateFPExt(x, b_->getFloatTy());
+ }
+
+ auto get_float = [&](const double f) {
+ return llvm::ConstantFP::get(x->getType(), f);
};
- auto multiply_add = [&](absl::Span<const float> coefficients,
+ auto multiply_add = [&](absl::Span<const double> coefficients,
llvm::Value* w) {
- llvm::Value* p = getFloat(coefficients.front());
+ llvm::Value* p = get_float(coefficients.front());
coefficients.remove_prefix(1);
for (float coefficient : coefficients) {
- p = FAdd(FMul(p, w), getFloat(coefficient));
+ p = FAdd(FMul(p, w), get_float(coefficient));
}
return p;
};
// Approximation for inverse error function from
// Giles, M., "Approximating the erfinv function".
- // The approximation has the form:
- // w = log((1-x)*(1+x))
+ // The approximation has the form (float version):
+ // w = -log((1-x)*(1+x))
// if ( w < 5 ) {
// w = w - 2.5
// p = sum_{i=1}^n lq[i]*w^i
@@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {b_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {x->getType()});
- llvm::Value* w = FNeg(
- Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = FNeg(Call(
+ logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))}));
llvm::Value* p_addr =
- llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
+ llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_);
+
+ if (prim_type == F16 || prim_type == F32) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(2.5f));
+ absl::Span<const double> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ Store(p, p_addr);
+ }
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
- // Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, b_);
- {
- llvm::Value* lw = FSub(w, getFloat(2.5f));
- absl::Span<const float> lq{
- 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
- -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
- -0.00417768164f, 0.246640727f, 1.50140941f};
- llvm::Value* p = multiply_add(lq, lw);
- Store(p, p_addr);
- }
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f));
+ absl::Span<const double> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ Store(p, p_addr);
+ }
- // Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, b_);
- {
- llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
-
- llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f));
- absl::Span<const float> gq{
- -0.000200214257f, 0.000100950558f, 0.00134934322f,
- -0.00367342844f, 0.00573950773f, -0.0076224613f,
- 0.00943887047f, 1.00167406f, 2.83297682f};
- llvm::Value* p = multiply_add(gq, gw);
- Store(p, p_addr);
- }
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ } else {
+ DCHECK(prim_type == F64);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_);
+
+ SetToFirstInsertPoint(if_data.true_block, b_);
+ {
+ llvm::Value* lw = FSub(w, get_float(3.125));
+ absl::Span<const double> c{
+ -3.6444120640178196996e-21, -1.685059138182016589e-19,
+ 1.2858480715256400167e-18, 1.115787767802518096e-17,
+ -1.333171662854620906e-16, 2.0972767875968561637e-17,
+ 6.6376381343583238325e-15, -4.0545662729752068639e-14,
+ -8.1519341976054721522e-14, 2.6335093153082322977e-12,
+ -1.2975133253453532498e-11, -5.4154120542946279317e-11,
+ 1.051212273321532285e-09, -4.1126339803469836976e-09,
+ -2.9070369957882005086e-08, 4.2347877827932403518e-07,
+ -1.3654692000834678645e-06, -1.3882523362786468719e-05,
+ 0.0001867342080340571352, -0.00074070253416626697512,
+ -0.0060336708714301490533, 0.24015818242558961693,
+ 1.6536545626831027356};
+ llvm::Value* p = multiply_add(c, lw);
+ Store(p, p_addr);
+ }
- SetToFirstInsertPoint(if_data.after_block, b_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
+ llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse(
+ FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_);
+ SetToFirstInsertPoint(if_data_second.true_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25));
+ absl::Span<const double> t1{
+ 2.2137376921775787049e-09, 9.0756561938885390979e-08,
+ -2.7517406297064545428e-07, 1.8239629214389227755e-08,
+ 1.5027403968909827627e-06, -4.013867526981545969e-06,
+ 2.9234449089955446044e-06, 1.2475304481671778723e-05,
+ -4.7318229009055733981e-05, 6.8284851459573175448e-05,
+ 2.4031110387097893999e-05, -0.0003550375203628474796,
+ 0.00095328937973738049703, -0.0016882755560235047313,
+ 0.0024914420961078508066, -0.0037512085075692412107,
+ 0.005370914553590063617, 1.0052589676941592334,
+ 3.0838856104922207635};
+ llvm::Value* p = multiply_add(t1, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data_second.false_block, b_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()});
+
+ llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0));
+ absl::Span<const double> t2{
+ -2.7109920616438573243e-11, -2.5556418169965252055e-10,
+ 1.5076572693500548083e-09, -3.7894654401267369937e-09,
+ 7.6157012080783393804e-09, -1.4960026627149240478e-08,
+ 2.9147953450901080826e-08, -6.7711997758452339498e-08,
+ 2.2900482228026654717e-07, -9.9298272942317002539e-07,
+ 4.5260625972231537039e-06, -1.9681778105531670567e-05,
+ 7.5995277030017761139e-05, -0.00021503011930044477347,
+ -0.00013871931833623122026, 1.0103004648645343977,
+ 4.8499064014085844221};
+ llvm::Value* p = multiply_add(t2, gw);
+ Store(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ }
llvm::Value* p = Load(p_addr);
- return FMul(p, x);
+ x = FMul(p, x);
+ // Trunc back to half if needed.
+ if (prim_type == F16) {
+ x = b_->CreateFPTrunc(x, b_->getHalfTy());
+ }
+ return x;
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type,
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h
index 3cccec9862..986970f886 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph.h
+++ b/tensorflow/compiler/xla/service/flatten_call_graph.h
@@ -26,7 +26,7 @@ namespace xla {
// Flattening associates each call site with a unique computation (for
// sequential calling contexts) This simplifies buffer assignment and
// points-to analysis (see b/36865746 for details).
-class FlattenCallGraph : public HloPassInterface {
+class FlattenCallGraph : public HloModulePass {
public:
absl::string_view name() const override { return "flatten-call-graph"; }
diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h
index 7bd9ea5984..2b39359aae 100644
--- a/tensorflow/compiler/xla/service/gather_expander.h
+++ b/tensorflow/compiler/xla/service/gather_expander.h
@@ -23,7 +23,7 @@ namespace xla {
// This pass rewrites gather operations into (roughly) while loops of dynamic
// slices. This lets backends that don't support gather directly to
// nevertheless have a minimum level of support.
-class GatherExpander : public HloPassInterface {
+class GatherExpander : public HloModulePass {
public:
absl::string_view name() const override { return "gather_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 64b9683628..51968d13d4 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -68,9 +68,7 @@ cc_library(
# srcs = [
# "partition_assignment_test.cc",
# ],
-# tags = [
-# "requires-gpu-sm35",
-# ],
+# tags = tf_cuda_tests_tags(),
# deps = [
# ":partition_assignment",
# "//tensorflow/core:stream_executor_no_cuda",
@@ -373,7 +371,6 @@ cc_library(
hdrs = ["ir_emission_utils.h"],
deps = [
":backend_configs",
- ":cudnn_convolution_runner",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
@@ -414,6 +411,8 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
@@ -422,8 +421,10 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -432,6 +433,7 @@ cc_library(
srcs = ["cudnn_convolution_rewriter.cc"],
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
+ ":backend_configs",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:util",
@@ -596,14 +598,11 @@ cc_library(
hdrs = ["pad_for_tensor_cores.h"],
deps = [
":ir_emission_utils",
- "//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo_creation_utils",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/compiler/xla/service:shape_inference",
],
)
@@ -656,6 +655,7 @@ cc_library(
deps = [
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
+ ":cudnn_fused_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@@ -783,6 +783,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -967,3 +968,19 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
+
+cc_library(
+ name = "cudnn_fused_convolution_rewriter",
+ srcs = ["cudnn_fused_convolution_rewriter.cc"],
+ hdrs = ["cudnn_fused_convolution_rewriter.h"],
+ deps = [
+ ":backend_configs",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_casting_utils",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
index 640c6392b8..78e14d860e 100644
--- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto
+++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto
@@ -24,4 +24,18 @@ message CudnnConvBackendConfig {
// true, cudnn may choose not to use tensor cores, e.g. because the GPU or
// selected algorithm doesn't support it.
bool tensor_ops_enabled = 2;
+
+ // The scaling factor multiplied with the convolution result.
+ double conv_result_scale = 4;
+
+ // Below are the fields related to cuDNN's fused convolution. Refer to
+ // CudnnConvParams for their meanings.
+
+ // The requested activation (e.g. relu) after the convolution. It is with type
+ // stream_executor::dnn::ActivationMode.
+ int64 activation_mode = 3;
+
+ // The scaling factor multiplied with the side input. If no side input buffer
+ // is provided, this field must be 0.
+ double side_input_scale = 5;
}
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 3a23ac1d63..4effea637d 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -29,37 +29,38 @@ limitations under the License.
namespace xla {
namespace gpu {
-using se::dnn::AlgorithmDesc;
+ConvolutionThunk::ConvolutionThunk(
+ const HloCustomCallInstruction* cudnn_call,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice,
+ BufferAllocation::Slice tuple_result_slice)
+ : Thunk(Kind::kConvolution, cudnn_call),
+ cudnn_call_(cudnn_call),
+ operand_buffers_(std::move(operand_slices)),
+ result_buffer_(result_slice),
+ scratch_buffer_(scratch_slice),
+ tuple_result_buffer_(tuple_result_slice) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
- CudnnConvParams params;
+ std::vector<se::DeviceMemoryBase> operand_se_buffers;
+ for (const auto& buffer : operand_buffers_) {
+ operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer));
+ }
+
+ se::DeviceMemoryBase result_buffer =
+ buffer_allocations.GetDeviceAddress(result_buffer_);
- params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_);
- params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_);
- params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_);
se::DeviceMemoryBase scratch =
buffer_allocations.GetDeviceAddress(scratch_buffer_);
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, &params));
-
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
- TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream));
+ TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_,
+ absl::MakeSpan(operand_se_buffers),
+ result_buffer, scratch, stream));
- // Figure out which of output/input/filter is the result produced by
- // this op, and write the result tuple.
- void* result_ptr = [&] {
- switch (params.kind) {
- case CudnnConvKind::kForward:
- return params.output_buf.opaque();
- case CudnnConvKind::kBackwardInput:
- return params.input_buf.opaque();
- case CudnnConvKind::kBackwardFilter:
- return params.filter_buf.opaque();
- }
- }();
- void* ptrs[] = {result_ptr, scratch.opaque()};
+ void* ptrs[] = {result_buffer.opaque(), scratch.opaque()};
se::DeviceMemory<void*> tuple_addr(
buffer_allocations.GetDeviceAddress(tuple_result_buffer_));
stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr);
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
index d7d1f91fba..f53bc54198 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk {
// Constructs a thunk for launching a DNN convolution. When run, it will
// write a tuple (result, scratch_memory) into `tuple_result_buffer`.
//
- // Note that "output" here doesn't refer to the output from running this
- // thunk, but rather to the "output" of a hypothetical forward convolution
- // that corresponds to this input+filter+output triple. That is, the result
- // generated by this thunk is "output" for forward convs, "input" for
- // backward-input convs, and "filter" for backward-filter convs.
+ // operand_slices should be in the same order as cudnn_call->operands().
ConvolutionThunk(const HloCustomCallInstruction* cudnn_call,
- BufferAllocation::Slice input_slice,
- BufferAllocation::Slice filter_slice,
- BufferAllocation::Slice output_slice,
+ std::vector<BufferAllocation::Slice> operand_slices,
+ BufferAllocation::Slice result_slice,
BufferAllocation::Slice scratch_slice,
- BufferAllocation::Slice tuple_result_slice)
- : Thunk(Kind::kConvolution, cudnn_call),
- cudnn_call_(cudnn_call),
- input_buffer_(std::move(input_slice)),
- filter_buffer_(std::move(filter_slice)),
- output_buffer_(std::move(output_slice)),
- scratch_buffer_(std::move(scratch_slice)),
- tuple_result_buffer_(std::move(tuple_result_slice)) {}
+ BufferAllocation::Slice tuple_result_slice);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk {
private:
const HloCustomCallInstruction* cudnn_call_;
- BufferAllocation::Slice input_buffer_;
- BufferAllocation::Slice filter_buffer_;
- BufferAllocation::Slice output_buffer_;
+ std::vector<BufferAllocation::Slice> operand_buffers_;
+ BufferAllocation::Slice result_buffer_;
BufferAllocation::Slice scratch_buffer_;
BufferAllocation::Slice tuple_result_buffer_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
index 6e2e330edd..c3f58508dd 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h
@@ -52,7 +52,7 @@ namespace gpu {
// The GPU backend does not implement a lowering for the batchnorm HLOs -- it
// expects them to be lowered to cudnn calls via this pass or to HLO soup via
// BatchNormRewriter.
-class CudnnBatchNormRewriter : public HloPassInterface {
+class CudnnBatchNormRewriter : public HloModulePass {
public:
absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index c607aea1a8..7125673887 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -76,54 +76,24 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
return se::DeviceMemory<uint8>(buffer_addr);
}
-// Determines whether we can safely perform a winograd non-fused convolution for
-// the given input and output shapes. This works around b/68264959, an integer
-// overflow in cuDNNv5 and cuDNNv6.
-bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape,
- const Shape& output_shape,
- const ConvolutionDimensionNumbers& dnums,
- se::StreamExecutor* stream_exec) {
- // Skip this check for cudnn7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
- if (version.ok() && version.ValueOrDie().major_version() >= 7) {
- return true;
- }
-
- int64 batch = input_shape.dimensions(dnums.input_batch_dimension());
- int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension());
- int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0));
- int64 in_cols =
- dnums.input_spatial_dimensions_size() == 1
- ? 1
- : input_shape.dimensions(dnums.input_spatial_dimensions(1));
- int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension());
-
- int64 total_size = CeilOfRatio(batch, int64{16}) *
- std::max(in_depths, out_depths) * in_cols * in_rows *
- sizeof(float);
-
- const int64 threshold = 1L << 31;
- return total_size < threshold;
-}
-
std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
- bool with_winograd_nonfused,
se::StreamExecutor* stream_exec) {
std::vector<AlgorithmDesc> algorithms;
+ bool succ = false;
switch (kind) {
case CudnnConvKind::kBackwardFilter:
- CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ =
+ stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kBackwardInput:
- CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(
- with_winograd_nonfused, &algorithms));
+ succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms);
break;
case CudnnConvKind::kForward:
- CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused,
- &algorithms));
+ case CudnnConvKind::kForwardActivation:
+ succ = stream_exec->GetConvolveAlgorithms(true, &algorithms);
break;
}
+ DCHECK(succ);
return algorithms;
}
@@ -177,19 +147,11 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// caching would speed up compilation a lot.
StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
- const HloCustomCallInstruction* instr) {
- CudnnConvParams params;
- TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, &params));
-
- const Shape& input_shape = *params.input_shape;
- const Shape& filter_shape = *params.filter_shape;
- const Shape& output_shape = *params.output_shape;
-
- CHECK_EQ(input_shape.element_type(), filter_shape.element_type());
- CHECK_EQ(input_shape.element_type(), output_shape.element_type());
+ HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
// with some work on the HLO routines.
- const bool cross_check_enabled = input_shape.element_type() == xla::F16;
+ const bool cross_check_enabled =
+ instr->shape().tuple_shapes(0).element_type() == xla::F16;
// Don't run this function concurrently on the same GPU.
//
@@ -221,25 +183,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
allocator = &*se_allocator;
}
- // Allocate space for the input, filter, and output of the convolution. We
- // use a ScratchAllocator for this instead of calling allocator_ directly so
- // that our allocations don't leak.
- ScratchAllocator input_output_allocator(device_ordinal, allocator);
- TF_ASSIGN_OR_RETURN(params.input_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(input_shape)));
- TF_ASSIGN_OR_RETURN(params.filter_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(filter_shape)));
- TF_ASSIGN_OR_RETURN(params.output_buf,
- input_output_allocator.AllocateBytes(
- &stream, ShapeUtil::ByteSizeOf(output_shape)));
-
- if (cross_check_enabled) {
- // Broadcast a constant to the buffer, instead of zeroing the buffer. A
- // non-zero constant is useful for the cross checking, because zero-inputs
- // may not always reveal the bugs.
- const auto initialize_f16 = [&stream](DeviceMemoryBase buffer) {
+ const auto initialize_buffer = [&stream, cross_check_enabled](
+ DeviceMemoryBase buffer) {
+ if (cross_check_enabled) {
+ // Broadcast a constant to the buffer, instead of zeroing the buffer. A
+ // non-zero constant is useful for the cross checking, because zero-inputs
+ // may not always reveal the bugs.
CHECK_EQ(0, (uintptr_t)buffer.opaque() % 4);
size_t left_over_bytes = buffer.size() % 4;
CHECK_EQ(0, left_over_bytes % 2);
@@ -257,51 +206,56 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
DeviceMemoryBase left_over(
static_cast<char*>(buffer.opaque()) + aligned_size, left_over_bytes);
stream.ThenMemcpy(&left_over, halfs, left_over_bytes);
- };
- initialize_f16(params.input_buf);
- initialize_f16(params.filter_buf);
- initialize_f16(params.output_buf);
- } else {
- // Although we don't have evidence this matters, zero out the buffers before
- // autotuning. It's conceivable that using uninitialized memory as the
- // inputs might affect performance if e.g. the inputs contain denormals, and
- // this is easy enough.
- stream.ThenMemZero(&params.input_buf, params.input_buf.size())
- .ThenMemZero(&params.filter_buf, params.filter_buf.size())
- .ThenMemZero(&params.output_buf, params.output_buf.size());
- }
-
- DeviceMemoryBase* result_buf = [&] {
- switch (params.kind) {
- case CudnnConvKind::kBackwardFilter:
- return &params.filter_buf;
- case CudnnConvKind::kBackwardInput:
- return &params.input_buf;
- case CudnnConvKind::kForward:
- return &params.output_buf;
+ } else {
+ // Although we don't have evidence this matters, zero out the buffers
+ // before autotuning. It's conceivable that using uninitialized memory as
+ // the inputs might affect performance if e.g. the inputs contain
+ // denormals, and this is easy enough.
+ stream.ThenMemZero(&buffer, buffer.size());
}
- }();
+ };
+
+ // Allocate space for the input, filter, and output of the convolution. We
+ // use a ScratchAllocator for this instead of calling allocator_ directly so
+ // that our allocations don't leak.
+ ScratchAllocator input_output_allocator(device_ordinal, allocator);
+ std::vector<se::DeviceMemoryBase> operand_buffers;
+ for (const auto* operand : instr->operands()) {
+ TF_ASSIGN_OR_RETURN(auto buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(operand->shape())));
+ initialize_buffer(buffer);
+ operand_buffers.push_back(buffer);
+ }
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+ initialize_buffer(result_buffer);
- const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
- input_shape, output_shape, *params.dnums, stream_exec_);
se::dnn::ProfileResult best_result;
int64 best_result_bytes_used = 0;
+ TF_ASSIGN_OR_RETURN(auto backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
optional<F16BufferComparator> comparator;
// Use the first algorithm that's supported as reference. There isn't a
// particular reason to use it, as any algorithm sufficies. It doesn't make
// this algorithm considered correct, though.
optional<AlgorithmDesc> first_algorithm;
- for (const AlgorithmDesc& alg :
- GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) {
+ TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
+ for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) {
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
- params.algorithm = AlgorithmConfig(alg);
- bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream,
- &profile_result)
+ backend_config.set_algorithm(alg.algo_id());
+ backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled());
+ TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config));
+ bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers),
+ result_buffer, &scratch_allocator,
+ &stream, &profile_result)
.ok();
if (launch_ok && profile_result.is_valid()) {
@@ -312,7 +266,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
.xla_gpu_crash_on_verification_failures();
if (comparator.has_value()) {
StatusOr<bool> result = comparator->CompareEqual(
- se::DeviceMemory<Eigen::half>(*result_buf));
+ se::DeviceMemory<Eigen::half>(result_buffer));
if (!result.ok()) {
LOG(ERROR) << "Unable to compare "
<< AlgorithmToString(*first_algorithm) << " against "
@@ -330,7 +284,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
}
} else if (cross_check_enabled) {
auto comp = F16BufferComparator::Create(
- se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator,
+ se::DeviceMemory<Eigen::half>(result_buffer), compiler_, allocator,
&stream);
if (comp.ok()) {
comparator.emplace(comp.ConsumeValueOrDie());
@@ -404,13 +358,14 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
ShapeUtil::MakeShape(U8, {scratch_bytes})});
- CudnnConvBackendConfig backend_config;
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ instr->backend_config<CudnnConvBackendConfig>());
backend_config.set_algorithm(algorithm);
backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
- instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0),
- instr->mutable_operand(1)}));
+ instr->CloneWithNewOperands(new_call_shape, instr->operands()));
+
TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config));
// Repackage new_call so it has the same shape as the original call, namely
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index f79b113f8f..aeda2fc7f8 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -30,7 +30,7 @@ namespace gpu {
// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
// each and adding explicit scratch space to the CustomCalls.
-class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
+class CudnnConvolutionAlgorithmPicker : public HloModulePass {
public:
// If the `allocator` parameter is not null, we will use it to allocate temp
// memory while timing the various convolution algorithms. If it's null,
@@ -50,7 +50,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- const HloCustomCallInstruction* instr);
+ HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index 3d1266355b..ef29237301 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -35,6 +36,32 @@ namespace gpu {
namespace {
+HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ int64 feature_group_count) {
+ HloComputation* computation = lhs->parent();
+
+ // This call returns a tuple of (conv_result, scratch_memory), where
+ // conv_result is the actual result of the convolution, and scratch_memory is
+ // temporary memory used by cudnn.
+ //
+ // At the moment, we don't know how much scratch memory this conv is going to
+ // use, so we put u8[0] in this place. Later on another pass will choose
+ // which conv algorithm to use, and at that point we'll modify the shape of
+ // this second tuple element.
+ Shape call_shape =
+ ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
+
+ HloInstruction* custom_call = computation->AddInstruction(
+ HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
+ custom_call->set_window(window);
+ custom_call->set_convolution_dimension_numbers(dnums);
+ custom_call->set_feature_group_count(feature_group_count);
+ return custom_call;
+}
+
bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
const ConvolutionDimensionNumbers& dnums =
conv->convolution_dimension_numbers();
@@ -263,7 +290,7 @@ MatchBackwardInput(HloInstruction* conv) {
!(window_util::HasBaseDilation(conv->window()) &&
(reverse_filter->IsConstant() || is_1x1_filter))) {
VLOG(1) << "Can't match to backwards convolution. Either filter is not "
- "kReverse, or it's not a base-dialted conv with a 1x1 or "
+ "kReverse, or it's not a base-dilated conv with a 1x1 or "
"constant filter.";
return no_match_result;
}
@@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) {
return std::make_tuple(true, new_window, dnums, rhs);
}
+CudnnConvBackendConfig GetDefaultBackendConfig() {
+ CudnnConvBackendConfig config;
+ config.set_conv_result_scale(1);
+ return config;
+}
+
// Tries to rewrite a single convolution into a call to cudnn.
StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
CHECK_EQ(conv->opcode(), HloOpcode::kConvolution);
@@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
std::tie(match, window, dnums) = MatchBackwardFilter(conv);
if (match) {
- return CreateCudnnConvBackwardFilter(
- conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1),
- window, dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(),
+ conv->mutable_operand(0), conv->mutable_operand(1),
+ window, dnums, conv->feature_group_count());
}
std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv);
if (match) {
- return CreateCudnnConvBackwardInput(conv->shape(),
- conv->mutable_operand(0), rhs, window,
- dnums, conv->feature_group_count());
+ return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(),
+ conv->mutable_operand(0), rhs, window, dnums,
+ conv->feature_group_count());
}
// If all else fails, try a forward convolution.
if (CanImplementAsCudnnForwardConv(conv)) {
- return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0),
- conv->mutable_operand(1), conv->window(),
- conv->convolution_dimension_numbers(),
- conv->feature_group_count());
+ return CreateCudnnConv(
+ kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0),
+ conv->mutable_operand(1), conv->window(),
+ conv->convolution_dimension_numbers(), conv->feature_group_count());
}
return nullptr;
@@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) {
return false;
}
+ TF_RETURN_IF_ERROR(
+ custom_call->set_backend_config(GetDefaultBackendConfig()));
+
// The CustomCall returns a tuple (conv_result, scratch_memory). Extract out
// the conv result and replace `conv` with it.
TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
index fbe7e98494..8d7c6fdab5 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h
@@ -24,7 +24,7 @@ namespace gpu {
// Rewrites plain convolutions, backwards-filter convolutions, and
// backwards-input convolutions into CustomCall HLOs that call into cuDNN.
-class CudnnConvolutionRewriter : public HloPassInterface {
+class CudnnConvolutionRewriter : public HloModulePass {
public:
absl::string_view name() const override {
return "cudnn-convolution-rewriter";
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 2a86ac265e..89dd1bb272 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -37,6 +39,42 @@ using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;
+struct CudnnConvParams {
+ // Here are the fields related to cuDNN's fused convolution. The result thus
+ // is defined as:
+ // activation(conv_result_scale * conv(x, w) +
+ // side_input_scale * side_input + broadcast(bias))
+ //
+ // The most common fused conv is conv forward + relu/identity, for example.
+ //
+ // bias_buf is a single-dimensional array, with the length equal to the number
+ // of output features. It'll be broadcasted to the output shape in order to be
+ // added to the final results.
+ //
+ // side_input_buf, if valid, must have the same shape as the output buffer.
+ struct FusionParams {
+ se::dnn::ActivationMode mode;
+ double side_input_scale;
+ se::DeviceMemoryBase bias_buf;
+ se::DeviceMemoryBase side_input_buf; // nullable
+ };
+
+ CudnnConvKind kind;
+ const Shape* input_shape;
+ const Shape* filter_shape;
+ const Shape* output_shape;
+ se::DeviceMemoryBase input_buf;
+ se::DeviceMemoryBase filter_buf;
+ se::DeviceMemoryBase output_buf;
+ const Window* window;
+ const ConvolutionDimensionNumbers* dnums;
+ int64 feature_group_count;
+ se::dnn::AlgorithmConfig algorithm;
+ double conv_result_scale;
+
+ absl::optional<FusionParams> fusion;
+};
+
// A StreamExecutor ScratchAllocator that wraps a single XLA allocation,
// returning it (in its entirety) the first time Allocate() is called.
class ScratchBufAllocator : public se::ScratchAllocator {
@@ -92,9 +130,9 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
- VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
- VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
- VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }";
+ VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape);
+ VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape);
+ VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape);
VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";
@@ -186,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
switch (kind) {
case CudnnConvKind::kForward:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveWithAlgorithm(
input_descriptor, input_buf, filter_descriptor, filter_buf,
convolution_descriptor, output_descriptor, &output_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardInput:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardDataWithAlgorithm(
filter_descriptor, filter_buf, output_descriptor, output_buf,
convolution_descriptor, input_descriptor, &input_buf,
scratch_allocator, algorithm, profile_result);
break;
case CudnnConvKind::kBackwardFilter:
+ if (params.conv_result_scale != 1) {
+ return InternalError(
+ "StreamExecutor doesn't support scaled convolution: %lf.",
+ params.conv_result_scale);
+ }
stream->ThenConvolveBackwardFilterWithAlgorithm(
input_descriptor, input_buf, output_descriptor, output_buf,
convolution_descriptor, filter_descriptor, &filter_buf,
scratch_allocator, algorithm, profile_result);
break;
+ case CudnnConvKind::kForwardActivation: {
+ BatchDescriptor bias_desc;
+ bias_desc.set_count(1)
+ .set_height(1)
+ .set_width(1)
+ .set_feature_map_count(
+ output_shape.dimensions(dnums.output_feature_dimension()))
+ .set_layout(output_dl);
+
+ se::DeviceMemory<T> side_input(params.fusion->side_input_buf);
+ // If there is no side input, use output as the side input.
+ if (side_input.is_null()) {
+ if (params.fusion->side_input_scale != 0) {
+ return InternalError(
+ "Side input scale is not 0, yet no side input buffer is "
+ "provided");
+ }
+ // Since side-input scale is 0, the values in the side input don't
+ // matter. The simplest thing to do would be to pass in a null buffer
+ // for the side input, but cudnn doesn't allow this. cudnn does promise
+ // that if side-input-scale is 0 the side input won't be read, so we
+ // just pass in the output buffer, since it's handy and has the correct
+ // size.
+ side_input = output_buf;
+ }
+
+ stream->ThenFusedConvolveWithAlgorithm(
+ input_descriptor, input_buf, params.conv_result_scale,
+ filter_descriptor, filter_buf, convolution_descriptor, side_input,
+ params.fusion->side_input_scale, bias_desc,
+ DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode,
+ output_descriptor, &output_buf, scratch_allocator, algorithm,
+ profile_result);
+ break;
+ }
}
if (!stream->ok()) {
@@ -214,32 +302,104 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params,
return Status::OK();
}
-} // anonymous namespace
+// Returns the cudnn convolution parameters generated from conv, which must be a
+// custom-call to a cudnn convolution.
+StatusOr<CudnnConvParams> GetCudnnConvParams(
+ const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer) {
+ CudnnConvParams params;
-string CudnnConvKindToString(CudnnConvKind kind) {
- switch (kind) {
- case CudnnConvKind::kForward:
- return "forward";
- case CudnnConvKind::kBackwardFilter:
- return "backward_filter";
- case CudnnConvKind::kBackwardInput:
- return "backward_input";
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ const auto& target = conv->custom_call_target();
+ const auto& lhs_shape = conv->operand(0)->shape();
+ const auto& rhs_shape = conv->operand(1)->shape();
+ const auto& conv_result_shape = conv->shape().tuple_shapes(0);
+
+ params.window = &conv->window();
+ params.dnums = &conv->convolution_dimension_numbers();
+ params.feature_group_count = conv->feature_group_count();
+ params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
+ backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+ params.conv_result_scale = backend_config.conv_result_scale();
+
+ if (target == kCudnnConvForwardCallTarget) {
+ params.kind = CudnnConvKind::kForward;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ } else if (target == kCudnnConvBackwardInputCallTarget) {
+ params.kind = CudnnConvKind::kBackwardInput;
+ params.input_shape = &conv_result_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &lhs_shape;
+ params.input_buf = result_buffer;
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = operand_buffers[0];
+ } else if (target == kCudnnConvBackwardFilterCallTarget) {
+ params.kind = CudnnConvKind::kBackwardFilter;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &conv_result_shape;
+ params.output_shape = &rhs_shape;
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = result_buffer;
+ params.output_buf = operand_buffers[1];
+ } else if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ params.kind = CudnnConvKind::kForwardActivation;
+ params.input_shape = &lhs_shape;
+ params.filter_shape = &rhs_shape;
+ params.output_shape = &conv_result_shape;
+ params.fusion.emplace();
+ auto& fusion = *params.fusion;
+ if (backend_config.activation_mode() <
+ static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) {
+ fusion.mode = static_cast<se::dnn::ActivationMode>(
+ backend_config.activation_mode());
+ } else {
+ return InternalError("Bad activation mode: %s",
+ backend_config.ShortDebugString());
+ }
+ fusion.side_input_scale = backend_config.side_input_scale();
+ params.input_buf = operand_buffers[0];
+ params.filter_buf = operand_buffers[1];
+ params.output_buf = result_buffer;
+ params.fusion->bias_buf = operand_buffers[2];
+ if (operand_buffers.size() >= 4) {
+ params.fusion->side_input_buf = operand_buffers[3];
+ }
+ } else {
+ return InternalError("Unexpected custom call target: %s", target);
}
+ return params;
}
-Status RunCudnnConvolution(CudnnConvParams params,
+} // anonymous namespace
+
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
ScratchBufAllocator scratch_allocator(scratch_buf);
- return RunCudnnConvolution(params, &scratch_allocator, stream,
- profile_result);
+ return RunCudnnConvolution(conv, operand_buffers, result_buffer,
+ &scratch_allocator, stream, profile_result);
}
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result) {
- PrimitiveType output_primitive_type = params.output_shape->element_type();
+ TF_ASSIGN_OR_RETURN(CudnnConvParams params,
+ GetCudnnConvParams(conv, operand_buffers, result_buffer));
+
+ PrimitiveType output_primitive_type =
+ conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator,
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
index 381aa37a1b..61aec1cecc 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h
@@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -27,52 +30,8 @@ namespace gpu {
// This file contains low-level routines for running cudnn convolutions.
-// Different types of convolutions supported by cudnn.
-//
-// A way to think about these is that a convolution is defined by three arrays
-// -- the "input", the "filter", and the "output" -- and given any two of these,
-// we can compute the third. For example, a backward-input convolution takes as
-// input a filter and an "output" and produces an "input" such that if one were
-// to do a forward convolution of "input" using filter, the result would be
-// something with the same shape as "output".
-//
-// This way of thinking is not correct if you look at the values produced. For
-// example, a backward-input convolution is not actually the mathematical
-// inverse of a forward convolution. But it's right as far as the shapes and
-// "connectivity" (i.e. which elements of the input affect which elements of
-// the output) are concerned.
-enum class CudnnConvKind {
- kForward, // input + filter => output
- kBackwardInput, // filter + output => input
- kBackwardFilter, // input + output => filter
-};
-
-struct CudnnConvParams {
- CudnnConvKind kind;
- const Shape* input_shape;
- const Shape* filter_shape;
- const Shape* output_shape;
- se::DeviceMemoryBase input_buf;
- se::DeviceMemoryBase filter_buf;
- se::DeviceMemoryBase output_buf;
- const Window* window;
- const ConvolutionDimensionNumbers* dnums;
- int64 feature_group_count;
- se::dnn::AlgorithmConfig algorithm;
-};
-
-// Converts a CudnnConvKind value to a string.
-string CudnnConvKindToString(CudnnConvKind kind);
-
// Calls into cudnn to run the specified convolution.
//
-// Note that depending on the value of CudnnConvKind, the result of this call
-// may be written into input_buf, filter_buf, or output_buf!
-//
-// At the moment convolution with half data type is implemented with cudnn
-// PSEUDO_HALF configuration, that is, the input values are half and the
-// internal computation type is float.
-//
// We provide one overload which takes a scratch buffer, and another which takes
// an allocator which is responsible for allocating the scratch space. In
// theory the second one shouldn't be necessary -- users of this function could
@@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind);
// allocator and take note of how much memory is used. The next time you call
// the same conv, you can provide an explicitly preallocated scratch buffer of
// that size, if you like.
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::DeviceMemoryBase scratch_buf, se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
-Status RunCudnnConvolution(CudnnConvParams params,
+Status RunCudnnConvolution(const HloCustomCallInstruction* conv,
+ absl::Span<se::DeviceMemoryBase> operand_buffers,
+ se::DeviceMemoryBase result_buffer,
se::ScratchAllocator* scratch_allocator,
se::Stream* stream,
se::dnn::ProfileResult* profile_result = nullptr);
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
new file mode 100644
index 0000000000..3761c19cfc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc
@@ -0,0 +1,278 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Describes a matched pattern:
+// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+// Where side_input has the shape of output buffer, and bias is a 1D array with
+// the dimension of number of output features.
+struct ConvWithRelu {
+ HloInstruction* maximum;
+ HloCustomCallInstruction* conv;
+ HloInstruction* bias;
+ HloInstruction* side_input;
+ HloConstantInstruction* alpha_conv;
+ HloConstantInstruction* alpha_side_input;
+};
+
+absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::GetTupleElement;
+ using match::Maximum;
+ using match::MultiplyAnyOrder;
+ using match::Op;
+
+ // The pattern we want to match:
+ // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias));
+ //
+ // With its variants involving commute/reassociation of adds, multiplies, and
+ // max, and omission of alpha1, side_input, alpha2, or bias.
+
+ HloInstruction* relu_input;
+
+ // Match max(0, relu_input).
+ auto zero_pattern = Broadcast(match::ConstantScalar(0));
+ if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) &&
+ !Match(instr, Maximum(Op(&relu_input), zero_pattern))) {
+ return absl::nullopt;
+ }
+ HloInstruction* conv_instr = nullptr;
+ HloInstruction* alpha_conv_instr = nullptr;
+ HloInstruction* alpha_side_input_instr = nullptr;
+ HloInstruction* bias_broadcast_instr = nullptr;
+ HloInstruction* bias = nullptr;
+ HloInstruction* side_input = nullptr;
+
+ // These nodes will not be in the returned value, but we need to check them
+ // for single use.
+ HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr,
+ *mul1 = nullptr, *mul2 = nullptr;
+
+ const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias));
+ const auto conv_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr));
+ auto conv_pattern = GetTupleElement(
+ &gte, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern);
+ }();
+ const auto side_input_pattern = [&] {
+ auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr));
+ // If bias is already matched, match arbitrary additional input as side
+ // input. Note this may force a cheap operation (e.g. broadcast) to be
+ // materialized into a large buffer, as large as the output buffer.
+ //
+ // TODO(timshen): If in practice there are significant false positives, we
+ // should fix it.
+ auto side_input_pattern = Op(&side_input);
+ return AnyOf<HloInstruction>(
+ MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern),
+ side_input_pattern);
+ }();
+
+ {
+ // Try to match any of the following form of add, in any association:
+ // addends[0]
+ // addends[0] + addends[1]
+ // addends[0] + addends[1] + addends[2]
+ //
+ // Then try to match each addend with one of the three patterns: bias, conv,
+ // or side_input. Notice that side_input matching must go last, as it
+ // also matches a conv or a bias.
+ HloInstruction* addends[3] = {nullptr, nullptr, nullptr};
+ auto add3_pattern = [&] {
+ auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1]));
+ return AnyOf<HloInstruction>(
+ AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern,
+ Op(&addends[0]));
+ }();
+ CHECK(Match(relu_input, add3_pattern));
+ for (auto addend : addends) {
+ if (addend) {
+ if (bias == nullptr && Match(addend, bias_pattern)) {
+ CHECK(bias);
+ } else if (conv_instr == nullptr && Match(addend, conv_pattern)) {
+ CHECK(conv_instr);
+ } else if (side_input == nullptr && Match(addend, side_input_pattern)) {
+ CHECK(side_input);
+ } else {
+ return absl::nullopt;
+ }
+ }
+ }
+ }
+
+ if (conv_instr == nullptr) {
+ return absl::nullopt;
+ }
+
+ for (HloInstruction* instr :
+ {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) {
+ if (instr && instr->user_count() > 1) {
+ return absl::nullopt;
+ }
+ }
+
+ auto conv = Cast<HloCustomCallInstruction>(conv_instr);
+ auto bias_broadcast =
+ CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr);
+
+ if (conv->custom_call_target() != kCudnnConvForwardCallTarget) {
+ return absl::nullopt;
+ }
+
+ if (bias_broadcast) {
+ // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}.
+ if (bias_broadcast_instr->dimensions().size() != 1) {
+ return absl::nullopt;
+ }
+ if (bias_broadcast_instr->dimensions(0) !=
+ conv->convolution_dimension_numbers().output_feature_dimension()) {
+ return absl::nullopt;
+ }
+ }
+
+ return ConvWithRelu{
+ instr,
+ conv,
+ bias,
+ side_input,
+ CastOrNull<HloConstantInstruction>(alpha_conv_instr),
+ CastOrNull<HloConstantInstruction>(alpha_side_input_instr)};
+}
+
+StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu(
+ ConvWithRelu match) {
+ auto conv = match.conv;
+
+ HloComputation* computation = conv->parent();
+ PrimitiveType element_type = conv->operand(0)->shape().element_type();
+
+ const auto get_alpha_value =
+ [](HloConstantInstruction* instr) -> StatusOr<double> {
+ TF_ASSIGN_OR_RETURN(
+ auto alpha,
+ Cast<HloConstantInstruction>(instr)->literal().Convert(F64));
+ return alpha.GetFirstElement<double>();
+ };
+
+ double alpha_conv = 1;
+ if (match.alpha_conv) {
+ TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv));
+ }
+
+ double alpha_side_input;
+ if (match.side_input) {
+ if (match.alpha_side_input) {
+ TF_ASSIGN_OR_RETURN(alpha_side_input,
+ get_alpha_value(match.alpha_side_input));
+ } else {
+ alpha_side_input = 1;
+ }
+ } else {
+ CHECK(match.alpha_side_input == nullptr);
+ alpha_side_input = 0;
+ }
+
+ auto bias = match.bias;
+ if (!bias) {
+ auto zero = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::Zero(element_type)));
+
+ int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions(
+ conv->convolution_dimension_numbers().output_feature_dimension());
+ bias = computation->AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShapeWithDescendingLayout(element_type,
+ {num_output_feature}),
+ zero, {}));
+ }
+
+ CHECK(bias);
+ std::vector<HloInstruction*> args = {conv->mutable_operand(0),
+ conv->mutable_operand(1), bias};
+ if (match.side_input) {
+ args.push_back(match.side_input);
+ }
+ auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall(
+ conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget));
+ new_conv->set_window(conv->window());
+ new_conv->set_convolution_dimension_numbers(
+ conv->convolution_dimension_numbers());
+ TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config,
+ conv->backend_config<CudnnConvBackendConfig>());
+ config.set_activation_mode(
+ static_cast<int64>(se::dnn::ActivationMode::kRelu));
+ config.set_conv_result_scale(alpha_conv);
+ config.set_side_input_scale(alpha_side_input);
+ TF_RETURN_IF_ERROR(new_conv->set_backend_config(config));
+
+ VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name();
+ return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0),
+ new_conv, 0);
+}
+
+} // namespace
+
+StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) {
+ bool changed = false;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ std::vector<ConvWithRelu> matches;
+ int num_forward_convs = 0;
+ for (auto instr : computation->instructions()) {
+ auto match = FindConvWithRelu(instr);
+ if (match.has_value()) {
+ matches.push_back(*match);
+ }
+ if (auto call = DynCast<HloCustomCallInstruction>(instr)) {
+ if (call->custom_call_target() == kCudnnConvForwardCallTarget) {
+ num_forward_convs++;
+ }
+ }
+ }
+ VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size()
+ << " out of " << num_forward_convs << " forward convs.";
+ std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>>
+ replacements;
+ for (const ConvWithRelu& match : matches) {
+ TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match));
+ replacements.push_back({match.maximum, std::move(new_instr)});
+ changed = true;
+ }
+ for (auto& replacement : replacements) {
+ TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
+ replacement.first, std::move(replacement.second)));
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
new file mode 100644
index 0000000000..bd12aadded
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+namespace gpu {
+
+class CudnnFusedConvolutionRewriter : public HloModulePass {
+ public:
+ absl::string_view name() const override {
+ return "cudnn-fused-convolution-rewriter";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
index 7e3f5775b8..f19996edfe 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h
@@ -32,7 +32,7 @@ namespace gpu {
// 2) The result of merging the fusion instruction into its users would not
// increase bytes transferred.
//
-class FusionMerger : public HloPassInterface {
+class FusionMerger : public HloModulePass {
public:
absl::string_view name() const override { return "fusion merger"; }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 75f414e47f..79c74e7e8b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -34,15 +34,6 @@ namespace xla {
namespace gpu {
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
- HloInstruction* hlo) {
- HloInstruction*& copy = hlo_to_copy_map_[hlo];
- if (copy == nullptr) {
- TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo));
- }
- return copy;
-}
-
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
CopyInsertion generic_copy_insertion;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
index 8ffae18fe8..4c7e38ffeb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h
@@ -25,20 +25,11 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "copy-insertion"; }
StatusOr<bool> Run(HloModule* module) override;
-
- protected:
- // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making
- // duplicate copies.
- StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
- // A map containing all copies inserted to materialize operands of library
- // calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
index bbb3340760..9c64b4d10c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h
@@ -23,7 +23,7 @@ namespace xla {
// his pass should run early in the HLO pipeline and checks for HLO constructs
// which are not supported by the GPU backend and cannot be removed via HLO
// transformations (eg, sparse layouts).
-class GpuHloSupportChecker : public HloPassInterface {
+class GpuHloSupportChecker : public HloModulePass {
public:
GpuHloSupportChecker() = default;
~GpuHloSupportChecker() override = default;
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index d033faee8d..74352f26aa 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -21,8 +21,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -90,27 +92,33 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// operands and the output shape. Depending on the underlying algorithm, one of
// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints) {
- CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
- Shape input_shape;
- Shape filter_shape;
- Shape output_shape;
- const auto& target = instr->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->shape().tuple_shapes(0);
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_shape = instr->shape().tuple_shapes(0);
- filter_shape = instr->operand(1)->shape();
- output_shape = instr->operand(0)->shape();
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_shape = instr->operand(0)->shape();
- filter_shape = instr->shape().tuple_shapes(0);
- output_shape = instr->operand(1)->shape();
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints) {
+ Shape lhs_shape = instr->operand(0)->shape();
+ Shape rhs_shape = instr->operand(1)->shape();
+ Shape result_shape = instr->shape().tuple_shapes(0);
+
+ Shape* input_shape;
+ Shape* filter_shape;
+ Shape* output_shape;
+
+ TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr));
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ case CudnnConvKind::kForwardActivation:
+ input_shape = &lhs_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &result_shape;
+ break;
+ case CudnnConvKind::kBackwardInput:
+ input_shape = &result_shape;
+ filter_shape = &rhs_shape;
+ output_shape = &lhs_shape;
+ break;
+ case CudnnConvKind::kBackwardFilter:
+ input_shape = &lhs_shape;
+ filter_shape = &result_shape;
+ output_shape = &rhs_shape;
+ break;
}
{
@@ -127,8 +135,9 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
}
TF_ASSIGN_OR_RETURN(
- std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
- *output_shape.mutable_layout()),
+ std::tie(*input_shape->mutable_layout(),
+ *filter_shape->mutable_layout(),
+ *output_shape->mutable_layout()),
StreamExecutorConvLayoutsToXlaLayouts(
instr->convolution_dimension_numbers(), input, filter, output));
}
@@ -141,24 +150,23 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
instr, /*index=*/{0}));
// Set layouts of the instructions' shapes.
- if (target == kCudnnConvForwardCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(output_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(input_shape.layout(), *call_result_buf));
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1));
- TF_RETURN_IF_ERROR(
- constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf));
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << instr->custom_call_target();
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetBufferLayout(result_shape.layout(), *call_result_buf));
+ // instr->operand(2), if exists, is the bias buffer. There is no need to
+ // assign layout to it, as it has only one dimension.
+
+ // instr->opernad(3), if exists, is the side input buffer.
+ if (instr->operand_count() == 4) {
+ if (kind != CudnnConvKind::kForwardActivation) {
+ return InternalError(
+ "Invalid convolution. Conv has a side input, but kind is not fused "
+ "conv forward: %s",
+ instr->ToString());
+ }
+ // The side input layout must match the output layout.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3));
}
return Status::OK();
}
@@ -173,8 +181,8 @@ Status GpuLayoutAssignment::AddBackendConstraints(
++iterator) {
HloInstruction* instruction = *iterator;
if (IsCustomCallToDnnConvolution(*instruction)) {
- TF_RETURN_IF_ERROR(
- AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
+ TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall(
+ Cast<HloCustomCallInstruction>(instruction), constraints));
}
// For batched dot we require the default layout.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index ce24af1cf8..e2b96a81d4 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -47,7 +48,7 @@ class GpuLayoutAssignment : public LayoutAssignment {
private:
Status AddBackendConstraintsToDnnConvCustomCall(
- HloInstruction* instr, LayoutConstraints* constraints);
+ HloCustomCallInstruction* instr, LayoutConstraints* constraints);
se::StreamExecutor* stream_executor_;
};
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 22f43bc08b..ec3d8f9405 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget =
"__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
"__cudnn$convBackwardFilter";
+const char* const kCudnnConvBiasActivationForwardCallTarget =
+ "__cudnn$convBiasActivationForward";
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
if (hlo.opcode() != HloOpcode::kCustomCall) {
@@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
const auto& target = hlo.custom_call_target();
return target == kCudnnConvForwardCallTarget ||
target == kCudnnConvBackwardInputCallTarget ||
- target == kCudnnConvBackwardFilterCallTarget;
+ target == kCudnnConvBackwardFilterCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget;
}
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
@@ -145,59 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
IsCustomCallToDnnConvolution(hlo);
}
-static HloInstruction* CreateCudnnConv(const char* call_target,
- const Shape& shape, HloInstruction* lhs,
- HloInstruction* rhs,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- HloComputation* computation = lhs->parent();
-
- // This call returns a tuple of (conv_result, scratch_memory), where
- // conv_result is the actual result of the convolution, and scratch_memory is
- // temporary memory used by cudnn.
- //
- // At the moment, we don't know how much scratch memory this conv is going to
- // use, so we put u8[0] in this place. Later on another pass will choose
- // which conv algorithm to use, and at that point we'll modify the shape of
- // this second tuple element.
- Shape call_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})});
-
- HloInstruction* custom_call = computation->AddInstruction(
- HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target));
- custom_call->set_window(window);
- custom_call->set_convolution_dimension_numbers(dnums);
- custom_call->set_feature_group_count(feature_group_count);
- return custom_call;
-}
-
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel,
- window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output,
- reverse_filter, window, dnums, feature_group_count);
-}
-
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count) {
- return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input,
- output, window, dnums, feature_group_count);
-}
-
bool IsReductionToVector(const HloInstruction& reduce) {
if (HloOpcode::kReduce != reduce.opcode()) {
return false;
@@ -288,41 +238,35 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
value->getType());
}
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params) {
- TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
- custom_call->backend_config<CudnnConvBackendConfig>());
- const auto& target = custom_call->custom_call_target();
- const auto& lhs_shape = custom_call->operand(0)->shape();
- const auto& rhs_shape = custom_call->operand(1)->shape();
- const auto& conv_result_shape = custom_call->shape().tuple_shapes(0);
-
- params->window = &custom_call->window();
- params->dnums = &custom_call->convolution_dimension_numbers();
- params->feature_group_count = custom_call->feature_group_count();
- params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
- backend_config.algorithm(), backend_config.tensor_ops_enabled()));
-
+StatusOr<CudnnConvKind> GetCudnnConvKind(
+ const HloCustomCallInstruction* instr) {
+ absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
- params->kind = CudnnConvKind::kForward;
- params->input_shape = &lhs_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &conv_result_shape;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- params->kind = CudnnConvKind::kBackwardInput;
- params->input_shape = &conv_result_shape;
- params->filter_shape = &rhs_shape;
- params->output_shape = &lhs_shape;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- params->kind = CudnnConvKind::kBackwardFilter;
- params->input_shape = &lhs_shape;
- params->filter_shape = &conv_result_shape;
- params->output_shape = &rhs_shape;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
+ return CudnnConvKind::kForward;
+ }
+ if (target == kCudnnConvBackwardInputCallTarget) {
+ return CudnnConvKind::kBackwardInput;
+ }
+ if (target == kCudnnConvBackwardFilterCallTarget) {
+ return CudnnConvKind::kBackwardFilter;
+ }
+ if (target == kCudnnConvBiasActivationForwardCallTarget) {
+ return CudnnConvKind::kForwardActivation;
+ }
+ return InternalError("Unexpected call target: %s", target);
+}
+
+string CudnnConvKindToString(CudnnConvKind kind) {
+ switch (kind) {
+ case CudnnConvKind::kForward:
+ return "forward";
+ case CudnnConvKind::kBackwardFilter:
+ return "backward_filter";
+ case CudnnConvKind::kBackwardInput:
+ return "backward_input";
+ case CudnnConvKind::kForwardActivation:
+ return "forward with activation";
}
- return Status::OK();
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index 09c455cc1e..a64a616ab1 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -20,7 +20,6 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -30,6 +29,33 @@ limitations under the License.
namespace xla {
namespace gpu {
+// Different types of convolutions supported by cudnn.
+//
+// A way to think about these is that a convolution is defined by three arrays
+// -- the "input", the "filter", and the "output" -- and given any two of these,
+// we can compute the third. For example, a backward-input convolution takes as
+// input a filter and an "output" and produces an "input" such that if one were
+// to do a forward convolution of "input" using filter, the result would be
+// something with the same shape as "output".
+//
+// This way of thinking is not correct if you look at the values produced. For
+// example, a backward-input convolution is not actually the mathematical
+// inverse of a forward convolution. But it's right as far as the shapes and
+// "connectivity" (i.e. which elements of the input affect which elements of
+// the output) are concerned.
+enum class CudnnConvKind {
+ kForward, // input + filter => output
+ kBackwardInput, // filter + output => input
+ kBackwardFilter, // input + output => filter
+ kForwardActivation, // activation(conv(input, filter) + broadcast(bias) +
+ // (optionally) side_input) => output
+};
+
+StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
+
+// Converts a CudnnConvKind value to a string.
+string CudnnConvKindToString(CudnnConvKind kind);
+
constexpr int64 kWarpSize = 32;
// Returns true if `hlo` will be implemented as a call to BLAS gemm.
@@ -95,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;
+extern const char* const kCudnnConvBiasActivationForwardCallTarget;
// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
@@ -104,28 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget;
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);
-// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv.
-// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If
-// you want just the conv result, you'll need to get-tuple-element the value
-// returned by this function.
-//
-// The created cudnn call will use the default cudnn algorithm and no scratch
-// space.
-HloInstruction* CreateCudnnConvForward(const Shape& shape,
- HloInstruction* input,
- HloInstruction* kernel,
- const Window& window,
- const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardInput(
- const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-HloInstruction* CreateCudnnConvBackwardFilter(
- const Shape& shape, HloInstruction* input, HloInstruction* output,
- const Window& window, const ConvolutionDimensionNumbers& dnums,
- int64 feature_group_count);
-
// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);
@@ -150,11 +155,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt,
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
llvm::IRBuilder<>* builder);
-// Populates params using conv, which must be a custom-call to a cudnn
-// convolution. Does not modify any buffers in the params.
-Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call,
- CudnnConvParams* params);
-
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b669881026..c792dd2ddb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (IsCustomCallToDnnConvolution(*custom_call)) {
const auto& assn = ir_emitter_context_->buffer_assignment();
- auto lhs_slice = GetAllocationSlice(*custom_call->operand(0));
- auto rhs_slice = GetAllocationSlice(*custom_call->operand(1));
+ std::vector<BufferAllocation::Slice> operand_slices;
+ operand_slices.reserve(custom_call->operand_count());
+ for (const auto* operand : custom_call->operands()) {
+ operand_slices.push_back(GetAllocationSlice(*operand));
+ }
auto tuple_result_slice = GetAllocationSlice(*custom_call);
auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie();
auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie();
- const auto& target = custom_call->custom_call_target();
- BufferAllocation::Slice input_slice, filter_slice, output_slice;
-
- if (target == kCudnnConvForwardCallTarget) {
- input_slice = lhs_slice;
- filter_slice = rhs_slice;
- output_slice = conv_result_slice;
- } else if (target == kCudnnConvBackwardInputCallTarget) {
- input_slice = conv_result_slice;
- filter_slice = rhs_slice;
- output_slice = lhs_slice;
- } else if (target == kCudnnConvBackwardFilterCallTarget) {
- input_slice = lhs_slice;
- filter_slice = conv_result_slice;
- output_slice = rhs_slice;
- } else {
- LOG(FATAL) << "Unexpected custom call target: "
- << custom_call->custom_call_target();
- }
-
thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>(
- Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice,
- output_slice, scratch_slice, tuple_result_slice));
+ Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
+ conv_result_slice, scratch_slice, tuple_result_slice));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index dfdcf1875d..0b3b429710 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h"
@@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
pipeline.AddPass<CudnnConvolutionRewriter>();
+ pipeline.AddPass<CudnnFusedConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
if (IsVoltaOrLater(*stream_exec)) {
pipeline.AddPass<PadForTensorCores>();
@@ -402,7 +404,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
LOG(WARNING)
<< "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "."
<< vdot
- << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
+ << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to "
"miscompile XLA code, leading to incorrect results or "
"invalid-address errors.\n\nYou do not need to update to CUDA "
"9.2.88; cherry-picking the ptxas binary is sufficient.";
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
index b0061fa655..e3869b5c36 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -36,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8;
// there's additional room for speedups. Achieving those speedups without also
// slowing other things down will likely require a more sophisticated heuristic,
// possibly some form of auto-tuning.
-static constexpr double kMaxBytesTouchedIncrease = 1.2;
+//
+// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4"
+// special case inside PadShape won't fire.
+static constexpr double kMaxBytesTouchedIncrease = 1.35;
// Pads the given dimensions in the given shape up to a multiple of
// kDesiredNumFeaturesFactor.
static Shape PadShape(Shape s, absl::Span<const int64> dims) {
for (int64 dim : dims) {
int64 dim_to_pad_size = s.dimensions(dim);
- int64 new_dim_to_pad_size =
- RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+
+ // Round dim_to_pad_size up to the next multiple of
+ // kDesiredNumFeaturesFactor.
+ //
+ // Special case: dims of size 3 are rounded up to 4, not
+ // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia),
+ // this helps, but as of writing, it's not supported by anything in the
+ // cudnn docs.
+ int64 new_dim_to_pad_size;
+ if (dim_to_pad_size == 3) {
+ new_dim_to_pad_size = 4;
+ } else {
+ new_dim_to_pad_size =
+ RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor);
+ }
+
s.set_dimensions(dim, new_dim_to_pad_size);
}
return s;
@@ -209,7 +227,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) {
std::vector<HloInstruction*> convs;
for (HloInstruction* instr : comp->instructions()) {
if (IsCustomCallToDnnConvolution(*instr) &&
- instr->operand(0)->shape().element_type() == F16) {
+ instr->operand(0)->shape().element_type() == F16 &&
+ // TODO(timshen): Disable for fused conv for now. Implement it if it's
+ // needed.
+ Cast<HloCustomCallInstruction>(instr)->custom_call_target() !=
+ kCudnnConvBiasActivationForwardCallTarget) {
convs.push_back(instr);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
index 11dc56a64f..e592a3774e 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h
@@ -30,7 +30,7 @@ namespace gpu {
// targeting before running this pass.
//
// TODO(jlebar): Also pad dots.
-class PadForTensorCores : public HloPassInterface {
+class PadForTensorCores : public HloModulePass {
public:
absl::string_view name() const override { return "pad for tensor cores"; }
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2a6415d0b6..b42a19e3a2 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -30,7 +30,8 @@ namespace gpu {
namespace {
bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
- CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget);
+ CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget ||
+ conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget);
return window_util::HasSymmetricPadding(conv.window()) &&
!window_util::HasNegativePadding(conv.window()) &&
!window_util::HasDilation(conv.window());
@@ -161,12 +162,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
- Shape old_conv_shape = conv->shape().tuple_shapes(0);
-
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(
- old_conv_shape, new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers(), conv->feature_group_count());
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
+ auto new_conv = conv->parent()->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), operands));
+ new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -242,10 +245,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
- HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
- backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ backward_conv->shape(), {padded_input, output}));
+ new_backward_conv->set_window(new_backward_conv_window);
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -308,9 +311,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* output = backward_conv->mutable_operand(0);
HloInstruction* filter = backward_conv->mutable_operand(1);
- HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
- new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv_call =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(
+ {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+ {output, filter}));
+ new_backward_conv_call->set_window(new_backward_conv_window);
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.
@@ -380,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) {
}
for (HloInstruction* instruction : convs) {
const auto& target = instruction->custom_call_target();
- if (target == kCudnnConvForwardCallTarget) {
+ if (target == kCudnnConvForwardCallTarget ||
+ target == kCudnnConvBiasActivationForwardCallTarget) {
changed |= CanonicalizeForwardConvolution(instruction);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
changed |= CanonicalizeBackwardFilterConvolution(instruction);
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
index a622e894ed..25cdf64c4c 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -24,7 +24,7 @@ namespace gpu {
// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
// inserts Pad instructions before Convolution instructions with uncanonicalized
// padding, so that they can be lowered to cuDNN convolution.
-class PadInsertion : public HloPassInterface {
+class PadInsertion : public HloModulePass {
public:
absl::string_view name() const override { return "pad insertion"; }
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index db4a33dc56..a725533567 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -25,15 +25,17 @@ filegroup(
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "gpu_codegen_test",
testonly = True,
srcs = ["gpu_codegen_test.cc"],
hdrs = ["gpu_codegen_test.h"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -48,9 +50,7 @@ cc_library(
tf_cc_test(
name = "gpu_copy_test",
srcs = ["gpu_copy_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -67,9 +67,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ftz_test",
srcs = ["gpu_ftz_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/core:test_main",
@@ -79,9 +77,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_index_test",
srcs = ["gpu_index_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -102,9 +98,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_infeed_test",
srcs = ["infeed_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -125,9 +119,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_kernel_tiling_test",
srcs = ["gpu_kernel_tiling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
@@ -142,7 +134,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_ldg_test",
srcs = ["gpu_ldg_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -159,9 +151,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_noalias_test",
srcs = ["gpu_noalias_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:literal",
@@ -178,9 +168,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_fusion_test",
srcs = ["gpu_fusion_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -194,9 +182,7 @@ tf_cc_test(
tf_cc_test(
name = "gpu_unrolling_test",
srcs = ["gpu_unrolling_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_module_config",
@@ -211,9 +197,7 @@ tf_cc_test(
name = "gpu_alignment_test",
testonly = True,
srcs = ["gpu_alignment_test.cc"],
- tags = [
- "requires-gpu-sm35",
- ],
+ tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
@@ -225,3 +209,17 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "cudnn_fused_convolution_rewriter_test",
+ srcs = ["cudnn_fused_convolution_rewriter_test.cc"],
+ tags = tf_cuda_tests_tags(),
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
new file mode 100644
index 0000000000..5632cac186
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc
@@ -0,0 +1,283 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/strings/str_replace.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class CudnnFusedConvolutionRewriterTest : public HloTestBase {
+ protected:
+ string GetOptimizedHlo(absl::string_view hlo_string) {
+ return backend()
+ .compiler()
+ ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest())
+ .ConsumeValueOrDie(),
+ backend().default_stream_executor(),
+ backend().memory_allocator())
+ .ConsumeValueOrDie()
+ ->ToString();
+ }
+
+ void TestMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convForward"))
+ << optimized_hlo_string;
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo_string.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo_string;
+ EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01}))
+ << optimized_hlo_string;
+ }
+ }
+
+ void TestNotMatchWithAllTypes(absl::string_view hlo_string) {
+ for (absl::string_view type : {"f16", "f32", "f64"}) {
+ const string hlo_with_new_type =
+ absl::StrReplaceAll(hlo_string, {{"TYPE", type}});
+ string optimized_hlo = GetOptimizedHlo(hlo_with_new_type);
+ EXPECT_NE(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convForward"))
+ << optimized_hlo;
+ EXPECT_EQ(absl::string_view::npos,
+ optimized_hlo.find("__cudnn$convBiasActivationForward"))
+ << optimized_hlo;
+ }
+ }
+};
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) {
+ // max(0, conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) {
+ // max(0, conv(x, w) + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ bias = TYPE[64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) {
+ // max(0, conv(x, w) + side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) {
+ // max(0, conv(x, w) + side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) {
+ // max(0, 0.999994934 * conv(x, w));
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={}
+ scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv)
+ ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) {
+ // max(0, conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest,
+ TestScaledConvAndScaledSideInputWithBias) {
+ // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias);
+ TestMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+ alpha_conv_scalar = TYPE[] constant(0.999994934)
+ alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={}
+ alpha_side_input_scalar = TYPE[] constant(0.899994934)
+ alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input = TYPE[1,3,3,64] parameter(2)
+ bias = TYPE[64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv)
+ scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input)
+ broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3}
+ add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias)
+ add2 = TYPE[1,3,3,64] add(add1, scaled_side_input)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) {
+ // max(0.1, conv(x, w)) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ point_one = TYPE[] constant(0.1)
+ point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={}
+
+ input = TYPE[1,17,9,9] parameter(0)
+ filter = TYPE[3,3,17,32] parameter(1)
+
+ conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1
+ ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv)
+ })");
+}
+
+TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) {
+ // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match.
+ TestNotMatchWithAllTypes(R"(
+ HloModule Test
+
+ ENTRY Test {
+ zero = TYPE[] constant(0)
+ zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={}
+
+ input = TYPE[1,3,3,64] parameter(0)
+ filter = TYPE[3,3,64,64] parameter(1)
+ side_input1 = TYPE[1,3,3,64] parameter(2)
+ side_input2 = TYPE[1,3,3,64] parameter(3)
+
+ conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1
+ add1 = TYPE[1,3,3,64] add(conv, side_input2)
+ add2 = TYPE[1,3,3,64] add(add1, side_input1)
+ ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2)
+ })");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index 40183de96e..9a61f8ac5a 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -26,9 +26,6 @@ limitations under the License.
namespace xla {
namespace {
-using ::testing::Eq;
-using ::testing::HasSubstr;
-
class WhileTransformerTest : public HloTestBase {
protected:
WhileTransformerTest()
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index e0f3a7e0e2..2bd04259c0 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -736,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() {
return result_;
}
+void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ result_.chunk_map.emplace(buffer, Chunk{0, 0});
+ return;
+ }
+ auto emplace_result = buffer_intervals_.emplace(
+ buffer, BufferInterval{buffer, size, current_time_, -1});
+ DCHECK(emplace_result.second);
+ ++current_time_;
+}
+
+void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer,
+ int64 size) {
+ // Degenerate case: 0-sized buffers are always allocated at offset 0.
+ if (size == 0) {
+ return;
+ }
+ BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer);
+ DCHECK_EQ(buffer_interval.buffer, buffer);
+ DCHECK_EQ(buffer_interval.size, size);
+ DCHECK_EQ(buffer_interval.end, -1);
+ buffer_interval.end = current_time_;
+ ++current_time_;
+}
+
+namespace {
+
+// Node in BufferIntervalTree that stores the alloc and free times of a buffer,
+// and the chunk assigned to it.
+struct BufferIntervalTreeNode {
+ // Alloc time.
+ int64 start;
+ // Free time.
+ int64 end;
+ // Maximum free time of all nodes in the subtree where this node is the root.
+ int64 subtree_end;
+ // Allocated chunk for the buffer.
+ HeapSimulator::Chunk chunk;
+ // Left child.
+ BufferIntervalTreeNode* left;
+ // Right child.
+ BufferIntervalTreeNode* right;
+};
+
+// An interval tree that can query buffers overlapping in time.
+class BufferIntervalTree {
+ public:
+ explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {}
+
+ using Chunk = HeapSimulator::Chunk;
+
+ // Adds a buffer to the interval tree, with the time interval and allocated
+ // chunk specified.
+ void Add(int64 start, int64 end, const Chunk& chunk) {
+ int index = node_count_;
+ DCHECK_LT(index, node_storage_.size());
+ ++node_count_;
+
+ node_storage_[index] =
+ BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr};
+
+ if (index == 0) {
+ // This is root.
+ return;
+ }
+
+ BufferIntervalTreeNode* parent = &node_storage_[0];
+ while (true) {
+ parent->subtree_end = std::max(parent->subtree_end, end);
+ if (parent->start > start) {
+ if (parent->left == nullptr) {
+ parent->left = &node_storage_[index];
+ return;
+ }
+ parent = parent->left;
+ } else {
+ if (parent->right == nullptr) {
+ parent->right = &node_storage_[index];
+ return;
+ }
+ parent = parent->right;
+ }
+ }
+ }
+
+ // Returns vector of allocated chunks that overlap with the given time
+ // interval.
+ std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) {
+ std::vector<Chunk> result;
+ if (node_count_ == 0) {
+ return result;
+ }
+ std::vector<BufferIntervalTreeNode*> visiting_stack;
+ visiting_stack.push_back(&node_storage_[0]);
+ while (!visiting_stack.empty()) {
+ BufferIntervalTreeNode* top = visiting_stack.back();
+ visiting_stack.pop_back();
+ if (start > top->subtree_end) {
+ continue;
+ }
+ if (top->left != nullptr) {
+ visiting_stack.push_back(top->left);
+ }
+ if (top->start <= end && top->end >= start) {
+ result.push_back(top->chunk);
+ }
+ if (end < top->start) {
+ continue;
+ }
+ if (top->right != nullptr) {
+ visiting_stack.push_back(top->right);
+ }
+ }
+ return result;
+ }
+
+ private:
+ int64 node_count_ = 0;
+ std::vector<BufferIntervalTreeNode> node_storage_;
+};
+
+} // namespace
+
+HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() {
+ std::vector<BufferInterval> sorted_buffer_intervals;
+ for (auto& entry : buffer_intervals_) {
+ sorted_buffer_intervals.push_back(entry.second);
+ }
+ std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(),
+ [](const BufferInterval& x, const BufferInterval& y) {
+ if (x.size != y.size) {
+ return x.size > y.size;
+ }
+ if (x.end - x.start != y.end - y.start) {
+ return x.end - x.start > y.end - y.start;
+ }
+ return x.buffer->id() < y.buffer->id();
+ });
+
+ BufferIntervalTree interval_tree(sorted_buffer_intervals.size());
+ for (auto& buffer_interval : sorted_buffer_intervals) {
+ auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime(
+ buffer_interval.start, buffer_interval.end);
+ std::sort(
+ chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(),
+ [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; });
+
+ // Find the minimum free chunk that can hold this buffer.
+ Chunk min_fit_chunk{-1, INT64_MAX};
+ auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) {
+ if (free_size < buffer_interval.size) {
+ return;
+ }
+
+ if (free_size < min_fit_chunk.size) {
+ min_fit_chunk = {free_offset, free_size};
+ }
+ };
+
+ int64 offset = 0;
+ for (auto& chunk : chunks_overlapping_in_time) {
+ if (offset < chunk.offset) {
+ use_free_chunk_if_smaller(offset, chunk.offset - offset);
+ }
+ offset =
+ std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_));
+ }
+ use_free_chunk_if_smaller(offset, result_.heap_size - offset);
+
+ if (min_fit_chunk.offset == -1) {
+ // Increase the heap size to fit in the last free chunk.
+ result_.heap_size = offset + buffer_interval.size;
+ min_fit_chunk = {offset, buffer_interval.size};
+ }
+
+ min_fit_chunk.size = buffer_interval.size;
+ const auto emplace_result =
+ result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk);
+ DCHECK(emplace_result.second);
+
+ interval_tree.Add(buffer_interval.start, buffer_interval.end,
+ min_fit_chunk);
+ }
+ return result_;
+}
+
+HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() {
+ DCHECK(!algorithms_.empty());
+ std::vector<Result> results(algorithms_.size());
+ int64 min_size = INT64_MAX;
+ int min_size_index = -1;
+ for (int i = 0; i < algorithms_.size(); ++i) {
+ results[i] = algorithms_[i]->Finish();
+ if (results[i].heap_size < min_size) {
+ min_size = results[i].heap_size;
+ min_size_index = i;
+ }
+ }
+
+ DCHECK_GE(min_size_index, 0);
+ return results[min_size_index];
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index ffbf947d5a..7d6dcc0dc9 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -351,6 +351,68 @@ class LazyBestFitHeap : public HeapAlgorithm {
std::set<Chunk, OrderChunkByIncreasingSize> free_;
};
+// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers,
+// then allocates them in decreasing sizes regardless of the alloc/free time. It
+// internally tracks the allocated buffers and their live intervals; when
+// allocating a buffer, it finds the best-fit free chunk during its live
+// interval.
+class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
+ public:
+ GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {}
+ ~GlobalDecreasingSizeBestFitHeap() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override;
+ void Free(const BufferValue* buffer, int64 size) override;
+ Result Finish() override;
+
+ private:
+ int64 alignment_;
+ Result result_;
+
+ // The current time represented as an integer. It increments by 1 at each
+ // Alloc or Free call.
+ int64 current_time_ = 0;
+
+ // BufferInterval stores a buffer's size and time interval.
+ struct BufferInterval {
+ const BufferValue* buffer;
+ int64 size;
+ // Alloc time of the buffer.
+ int64 start;
+ // Free time of the buffer.
+ int64 end;
+ };
+ tensorflow::gtl::FlatMap<const BufferValue*, BufferInterval>
+ buffer_intervals_;
+};
+
+// A heap algorithm that chooses the best results from other algorithms added to
+// it.
+class ChooseBestHeapAlgorithm : public HeapAlgorithm {
+ public:
+ ChooseBestHeapAlgorithm(
+ std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms)
+ : algorithms_(std::move(*algorithms)) {}
+ ~ChooseBestHeapAlgorithm() override {}
+
+ void Alloc(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Alloc(buffer, size);
+ }
+ }
+
+ void Free(const BufferValue* buffer, int64 size) override {
+ for (auto& algorithm : algorithms_) {
+ algorithm->Free(buffer, size);
+ }
+ }
+
+ Result Finish() override;
+
+ private:
+ std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 957c4a6891..191fbf8194 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -1021,5 +1021,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) {
EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset);
}
+class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {};
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) {
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(0, result.heap_size);
+ EXPECT_EQ(0, result.chunk_map.size());
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) {
+ // space
+ // ^
+ // | +---a---+
+ // | +-------+
+ // | +---c---+
+ // | +-------+
+ // | | b |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +-------+
+ // -----------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 30);
+ heap.Alloc(buffer_c_, 20);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Free(buffer_b_, 30);
+ heap.Free(buffer_c_, 20);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(100, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | |
+ // | | d |
+ // | +---a---+ +-------+
+ // |
+ // | +-------+
+ // | | |
+ // | | c |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 50);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 40);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 50);
+ heap.Free(buffer_d_, 40);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(120, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size);
+
+ EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset);
+}
+
+TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) {
+ // space
+ // ^
+ // | +-------+
+ // | +---b---+
+ // | +-------+
+ // | | d |
+ // | +--a--+ +-------+
+ // | +-------+
+ // | | |
+ // | | c |
+ // | +-------+
+ // | +-------+
+ // | | |
+ // | | e |
+ // | | |
+ // | +-------+
+ // ---------------------> time
+ GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1);
+ heap.Alloc(buffer_a_, 10);
+ heap.Alloc(buffer_b_, 20);
+ heap.Alloc(buffer_c_, 40);
+ heap.Free(buffer_a_, 10);
+ heap.Alloc(buffer_d_, 30);
+ heap.Alloc(buffer_e_, 50);
+ heap.Free(buffer_b_, 20);
+ heap.Free(buffer_c_, 40);
+ heap.Free(buffer_d_, 30);
+ heap.Free(buffer_e_, 50);
+
+ const HeapSimulator::Result result = heap.Finish();
+ EXPECT_EQ(140, result.heap_size);
+ EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size);
+ EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size);
+ EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size);
+ EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size);
+
+ EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset);
+ EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset);
+ EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset);
+ EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset);
+ EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 93ec2c9438..b19ec12638 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -309,6 +309,13 @@ message HeapSimulatorTrace {
bool whole_module_simulation = 2;
}
+// An abstraction representing a set of HLO module built to run concurrently
+// across different devices.
+message HloModuleGroupProto {
+ string name = 1;
+ repeated HloModuleProto hlo_modules = 2;
+}
+
// Serialization of BufferAssignment.
message BufferAssignmentProto {
// Alias represents a source LogicalBuffer, and the buffer location that
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 233d2199d1..0e5920af7a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -272,18 +272,19 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
<< "instruction " << instruction->name()
<< " has control successors and cannot be removed";
- TF_RET_CHECK(instruction_iterators_.count(instruction) != 0);
- auto inst_it = instruction_iterators_.at(instruction);
- (*inst_it)->set_parent(nullptr);
- instructions_.erase(inst_it);
+ auto inst_it = instruction_iterators_.find(instruction);
+ TF_RET_CHECK(inst_it != instruction_iterators_.end());
+ (*inst_it->second)->set_parent(nullptr);
+ instructions_.erase(inst_it->second);
+ instruction_iterators_.erase(inst_it);
return Status::OK();
}
-void HloComputation::set_root_instruction(
- HloInstruction* new_root_instruction) {
+void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
+ bool accept_different_shape) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!IsFusionComputation()) {
+ if (!IsFusionComputation() && !accept_different_shape) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape() << " is incompatible with "
@@ -562,9 +563,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
- return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
- &instructions, root,
- /*fusion_instruction=*/nullptr));
+ auto computation = absl::WrapUnique(
+ new HloComputation(proto.name(), parameter_count, &instructions, root,
+ /*fusion_instruction=*/nullptr));
+ computation->unique_id_ = proto.id();
+ return std::move(computation);
}
void HloComputation::FuseInstructionsInto(
@@ -914,13 +917,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone(
return CloneWithReplacements(
/*replacements=*/std::unordered_map<const HloInstruction*,
std::unique_ptr<HloInstruction>>(),
- context, suffix);
+ /*extras=*/{}, context, suffix);
}
std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context, const string& suffix) {
+ absl::Span<HloInstruction*> extras, HloCloneContext* context,
+ const string& suffix) {
std::unique_ptr<HloCloneContext> context_ptr;
if (context == nullptr) {
context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
@@ -942,6 +946,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
std::vector<HloInstruction*> postorder;
+ for (HloInstruction* instr : extras) {
+ postorder.push_back(instr);
+ }
for (HloInstruction* instr : MakeInstructionPostOrder()) {
if (HloInstruction* replacement = replace(instr)) {
postorder.push_back(replacement);
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 91c5234a6f..936a53bd7e 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -134,9 +134,11 @@ class HloComputation {
Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);
// Set the root of the computation to the given instruction. The instruction
- // must have already been added to the computation and have the same shape as
- // the result of the computation for non fusion computations.
- void set_root_instruction(HloInstruction* new_root_instruction);
+ // must have already been added to the computation. In addition it must have
+ // the same shape as the result of the computation for non fusion
+ // computations, except if accept_different_shape is set to true.
+ void set_root_instruction(HloInstruction* new_root_instruction,
+ bool accept_different_shape = false);
// Return the root instruction of the computation. The root instruction is the
// instruction which produces the output of the computation.
@@ -225,7 +227,7 @@ class HloComputation {
void UpdateReachabilityThroughInstruction(
const HloInstruction* instruction, HloReachabilityMap* reachability_map);
- int64 instruction_count() const { return instructions_.size(); }
+ int64 instruction_count() const { return instruction_iterators_.size(); }
// Creates and returns a list of the embedded computations called by this
// computation. This includes all embedded computations called directly or
@@ -331,10 +333,13 @@ class HloComputation {
//
// If replacements maps a key to nullptr, we remove that instruction from the
// new computation.
+ // If additional instructions are used by instructions in replacement map,
+ // they must be passed in post-order in the extras span.
std::unique_ptr<HloComputation> CloneWithReplacements(
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
replacements,
- HloCloneContext* context = nullptr, const string& suffix = "clone");
+ absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr,
+ const string& suffix = "clone");
// Returns true if the given instruction can be removed from the computation.
// Parameter instructions cannot be removed without violating invariants of
@@ -434,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h
index 4557983a9c..4a624cc7b8 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding.h
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// A pass which performs constant folding in order to avoid unnecessary
// computation on constants.
-class HloConstantFolding : public HloPassInterface {
+class HloConstantFolding : public HloModulePass {
public:
absl::string_view name() const override { return "constant_folding"; }
diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
index 4da42844bd..3e0def5d26 100644
--- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -37,7 +37,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using HloConstantFoldingTest = HloTestBase;
+using HloConstantFoldingTest = HloVerifiedTestBase;
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
HloComputation::Builder builder(TestName());
@@ -52,7 +52,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -73,7 +73,7 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -94,7 +94,7 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
EXPECT_THAT(computation->root_instruction(), op::Convert(input));
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
EXPECT_THAT(computation->root_instruction(), op::Constant());
@@ -134,7 +134,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -161,7 +161,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -186,7 +186,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
auto computation = module->AddEntryComputation(builder.Build());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module));
EXPECT_TRUE(result);
HloInstruction* root = computation->root_instruction();
@@ -219,28 +219,27 @@ const char* const kConstantFoldReduce = R"(
})";
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
+ ParseAndVerifyModule(kConstantFoldReduce);
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_TRUE(result);
- EXPECT_EQ(6, module->entry_computation()
+ EXPECT_EQ(6, module()
+ .entry_computation()
->root_instruction()
->literal()
.GetFirstElement<int32>());
}
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- ParseHloString(kConstantFoldReduce));
- HloInstruction* add = module->computations().begin()->root_instruction();
+ ParseAndVerifyModule(kConstantFoldReduce);
+ HloInstruction* add = module().computations().begin()->root_instruction();
LayoutUtil::ClearLayout(add->mutable_shape());
HloConstantFolding const_folder;
- TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(&module()));
EXPECT_FALSE(result);
- EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+ EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce());
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index b76c50bb5b..b2005d3c21 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@@ -201,6 +202,44 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloInstruction::CreateMap(map_shape, operands, map_computation));
}
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape()));
+ std::iota(all_dims.begin(), all_dims.end(), 0);
+
+ auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {});
+ HloComputation* reduce_computation;
+ {
+ HloComputation::Builder b(operand->name() + ".reduce_sub_computation");
+ auto lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ b.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs));
+ reduce_computation = module->AddEmbeddedComputation(b.Build());
+ }
+
+ return operand->parent()->AddInstruction(HloInstruction::CreateReduce(
+ scalar_shape, operand, init_value, all_dims, reduce_computation));
+}
+
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ HloComputation* computation = pred->parent();
+ DCHECK_EQ(computation, on_true->parent());
+ DCHECK_EQ(computation, on_false->parent());
+ TF_ASSIGN_OR_RETURN(Shape select_shape,
+ ShapeInference::InferTernaryOpShape(
+ HloOpcode::kSelect, pred, on_true, on_false));
+ return computation->AddInstruction(HloInstruction::CreateTernary(
+ select_shape, HloOpcode::kSelect, pred, on_true, on_false));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index b22058abb4..8e5ddbbd50 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
+#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -107,6 +108,35 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
HloComputation* map_computation);
+// Creates a Reduce HLO instruction and adds it to the computation containing
+// the operand. This will create the sub-computation needed for the reduction in
+// the given module. binary_opcode should represent a binary operation.
+StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
+ HloInstruction* init_value,
+ HloOpcode binary_opcode,
+ HloModule* module);
+
+// Creates a Select HLO instruction and adds it to the computation containing
+// the predicate. The on_true and on_false instructions must also be contained
+// in the same computation.
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false);
+
+// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
+// given values and adds it to the given computation.
+template <typename NativeT>
+StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
+ PrimitiveType type,
+ absl::Span<const NativeT> values) {
+ Literal literal = LiteralUtil::CreateR1<NativeT>(values);
+ if (literal.shape().element_type() != type) {
+ TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
+ }
+ return computation->AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+}
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index a28c03599a..e4857fd3fd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -25,7 +25,7 @@ namespace xla {
// and identical instructions with the same operands are commoned. The pass
// iterates over the instructions in topological order which enables the pass to
// find arbitrarily large common expressions.
-class HloCSE : public HloPassInterface {
+class HloCSE : public HloModulePass {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
index 1fe69b1395..4012042672 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -33,7 +33,7 @@ namespace xla {
//
// This pass does not remove dead parameter instructions, as parameter
// instructions cannot be deleted.
-class HloDCE : public HloPassInterface {
+class HloDCE : public HloModulePass {
public:
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
index d36631fc2f..c0bf1b9e16 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -30,7 +30,7 @@ namespace xla {
// used to break an HLO graph edge connecting two instructions with different
// sharding. If a set of connected instructions have all the same sharding, no
// kDomain instruction will be placed.
-class HloDomainIsolator : public HloPassInterface {
+class HloDomainIsolator : public HloModulePass {
public:
// Creates a new kDomain instruction for the edge between the use instruction
// (the first HloInstruction argument), and the operand instruction (the
diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h
index 97bc8ef604..0fc30fb86c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_remover.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h
@@ -26,7 +26,7 @@ namespace xla {
// Removes all the kDomain instructions of a given kind from the input module,
// and calls the normalizer to propagate the properties on the possibly new born
// instructions.
-class HloDomainRemover : public HloPassInterface {
+class HloDomainRemover : public HloModulePass {
public:
// Creates a new HloDomainRemover object tasked at removing all the kDomain
// instructions of a given kind.
diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
index 81d6d69a8c..bea5cba38d 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h
@@ -29,7 +29,7 @@ namespace xla {
// Verifies that the domain instructions are consistent, and the each domain is
// surrounded by the same metadata.
-class HloDomainVerifier : public HloPassInterface {
+class HloDomainVerifier : public HloModulePass {
public:
HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {}
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
index 44ded2c2fa..4d2a942925 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h
@@ -25,7 +25,7 @@ namespace xla {
// inserting Convert ops. This allows a backend to support an element type while
// only actually implementing the Convert op for that element type. This is
// generally not the fastest approach, but it works.
-class HloElementTypeConverter : public HloPassInterface {
+class HloElementTypeConverter : public HloModulePass {
public:
// eliminate_type is the type to eliminate as the input or output of ops,
// using Convert ops to replace it with replace_with_type.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 064b86493d..d7c39b2778 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) {
return Status::OK();
}
+Status HloEvaluator::HandleReal(HloInstruction* real) {
+ auto operand = real->operand(0);
+ switch (operand->shape().element_type()) {
+ case BF16: {
+ auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>(
+ real, [](bfloat16 elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case C64: {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ real, [](complex64 elem_operand) { return std::real(elem_operand); },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F16: {
+ auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>(
+ real, [](Eigen::half elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F32: {
+ auto result_or = ElementWiseUnaryOpImpl<float, float>(
+ real, [](float elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ case F64: {
+ auto result_or = ElementWiseUnaryOpImpl<double, double>(
+ real, [](double elem_operand) { return elem_operand; },
+ GetEvaluatedLiteralFor(operand));
+ TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or));
+ break;
+ }
+ default:
+ LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: "
+ << PrimitiveType_Name(operand->shape().element_type());
+ }
+
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleImag(HloInstruction* imag) {
+ auto result_or = ElementWiseUnaryOpImpl<float, complex64>(
+ imag, [](complex64 elem_operand) { return std::imag(elem_operand); },
+ GetEvaluatedLiteralFor(imag->operand(0)));
+
+ TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or));
+ return Status::OK();
+}
+
Status HloEvaluator::HandleCompare(HloInstruction* compare) {
HloOpcode opcode = compare->opcode();
auto lhs = compare->operand(0);
@@ -1173,80 +1228,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort,
TF_RET_CHECK(
ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape()))
<< "Sort keys and values must have the same dimensions";
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for rank-1 and rank-2 shapes, rank is: "
- << rank;
TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort";
- // We need to sort and array of keys and an array of values, where the
+ // We need to sort an array of keys and an array of values, where the
// sorted order of the values is determined by the keys. The simplest(?)
// way to do this is to go to an array-of-pairs representation, sort the
// array using the keys, and then go back to pair-of-arrays.
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
VLOG(3) << "HandleSort values_literal: " << values_literal.ToString();
- auto sort_r1 = [](const Literal& keys_literal,
- const Literal& values_literal) {
- const auto& keys_data = keys_literal.data<KeyType>();
- const auto& values_data = values_literal.data<ValueType>();
-
- using kv_pair = std::pair<KeyType, ValueType>;
- std::vector<kv_pair> key_value_vector;
- CHECK_EQ(keys_data.size(), values_data.size());
- key_value_vector.reserve(keys_data.size());
- for (int i = 0; i < keys_data.size(); ++i) {
- key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i]));
- }
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
- std::vector<KeyType> result_keys;
- std::vector<ValueType> result_values;
- for (const auto& key_value : key_value_vector) {
- result_keys.push_back(key_value.first);
- result_values.push_back(key_value.second);
- }
- Literal result_keys_literal(keys_literal.shape());
- result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys));
- Literal result_values_literal(values_literal.shape());
- result_values_literal.PopulateR1(
- absl::Span<const ValueType>(result_values));
- return std::make_pair(std::move(result_keys_literal),
- std::move(result_values_literal));
- };
-
- Literal result_tuple;
- if (rank == 1) {
- auto result_pair = sort_r1(keys_literal, values_literal);
- result_tuple =
- LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second});
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal keys_result_literal(keys_literal.shape());
- Literal values_result_literal(values_literal.shape());
- int64 r1_length = keys_literal.shape().dimensions(1);
- for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto keys_r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- TF_ASSIGN_OR_RETURN(auto values_r1_slice,
- values_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice);
- TF_ASSIGN_OR_RETURN(auto sorted_keys,
- r1_result_pair.first.Reshape({1, r1_length}));
- TF_ASSIGN_OR_RETURN(auto sorted_values,
- r1_result_pair.second.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
- sorted_keys, {0, 0}, {row, 0}, {1, r1_length}));
- TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
- sorted_values, {0, 0}, {row, 0}, {1, r1_length}));
- }
- result_tuple =
- LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
+ if (rank == 0) {
+ // Nothing to sort.
+ return LiteralUtil::MakeTuple({&keys_literal, &values_literal});
}
+ Literal keys_result_literal(keys_literal.shape());
+ Literal values_result_literal(values_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys_literal.shape(), zero_base,
+ AsInt64Slice(keys_literal.shape().dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the keys and values literals that correspond to
+ // exactly the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto keys_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& keys_data = keys_to_sort.data<KeyType>();
+ TF_ASSIGN_OR_RETURN(auto values_to_sort,
+ values_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& values_data = values_to_sort.data<ValueType>();
+ using kv_pair = std::pair<KeyType, ValueType>;
+ std::vector<kv_pair> key_value_vector;
+ key_value_vector.reserve(keys_data.size());
+ for (int i = 0; i < keys_data.size(); ++i) {
+ key_value_vector.push_back(
+ std::make_pair(keys_data[i], values_data[i]));
+ }
+ std::sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
+ std::vector<KeyType> result_keys;
+ std::vector<ValueType> result_values;
+ for (const auto& key_value : key_value_vector) {
+ result_keys.push_back(key_value.first);
+ result_values.push_back(key_value.second);
+ }
+ Literal sorted_keys(ShapeUtil::MakeShape(
+ keys_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys));
+ Literal sorted_values(ShapeUtil::MakeShape(
+ values_literal.shape().element_type(), {sort_dim_elements}));
+ sorted_values.PopulateR1(absl::Span<const ValueType>(result_values));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped,
+ sorted_keys.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom(
+ sorted_keys_reshaped, start_indices, indices, slice_dimensions));
+ TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped,
+ sorted_values.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom(
+ sorted_values_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+
+ Literal result_tuple;
+ result_tuple =
+ LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal});
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
return std::move(result_tuple);
}
@@ -1292,15 +1352,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort,
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- const int64 sort_dim = sort->dimensions(0);
- const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape());
- if (sort_dim != rank - 1) {
- return Unimplemented(
- "Trying to sort along dimension %d, which is not the last "
- "dimension",
- sort_dim);
- }
-
if (!ShapeUtil::IsTuple(sort->shape())) {
return DefaultAction(sort);
} else {
@@ -1339,6 +1390,12 @@ Status HloEvaluator::Preprocess(HloInstruction* hlo) {
Status HloEvaluator::Postprocess(HloInstruction* hlo) {
VLOG(2) << "Finished visiting " << hlo->ToString()
<< "; evaluated value is: " << GetEvaluatedLiteralFor(hlo).ToString();
+ // Out of convenience the literal may have been produced with a different
+ // layout. Relayout as indicated by the HLO instruction.
+ if (!LayoutUtil::LayoutsInShapesEqual(GetEvaluatedLiteralFor(hlo).shape(),
+ hlo->shape())) {
+ evaluated_.at(hlo) = evaluated_.at(hlo).Relayout(hlo->shape());
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 21e676d671..6c2662ebae 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleSort(HloInstruction* sort) override;
+ Status HandleReal(HloInstruction* real) override;
+
+ Status HandleImag(HloInstruction* imag) override;
+
Status HandleReduce(HloInstruction* reduce) override;
// Returns the already-evaluated literal result for the instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index 16411eb078..cee11a8a21 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -66,6 +66,20 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
.ConsumeValueOrDie();
}
+ // Evaluate function that takes in a local module instead of using module_
+ // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is
+ // removed, this should be the default Evaluate function.
+ Literal EvaluateWithModule(
+ HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
+ if (use_bfloat16_) {
+ // In BF16 mode, we convert all F32 type to BF16 and evaluate the module.
+ auto type_converter = HloElementTypeConverter(F32, BF16);
+ type_converter.Run(module).ValueOrDie();
+ }
+ return evaluator_->Evaluate(*module->entry_computation(), arg_literals)
+ .ConsumeValueOrDie();
+ }
+
std::unique_ptr<HloEvaluator> evaluator_;
void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
@@ -2530,6 +2544,114 @@ ENTRY main {
expected, Evaluate({&operand, &scatter_indices, &updates})));
}
+TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_NegativeIndices
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the negative indices.
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({-1, 2});
+ Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) {
+ const string hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ // No updates should happen for the OOB indices.
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}),
+ EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
+TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_text));
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ // Given the update window size of 2,2 and the index of 0,2, the update window
+ // will be OOB. So, nothing should be updated.
+ Literal expected = operand.Clone();
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ expected, EvaluateWithModule(module.get(),
+ {&operand, &scatter_indices, &updates})));
+}
+
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// element-wise comparison with 2 bfloat16 operands.
TEST_P(HloEvaluatorTest, DoesCompareBF16) {
@@ -2570,6 +2692,25 @@ ENTRY main {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, Evaluate({&arg})));
}
+TEST_P(HloEvaluatorTest, SliceWithDifferentLayout) {
+ // Regression test for b/114735354.
+ const string hlo_text = R"(
+HloModule SliceWithDifferentLayout
+
+ENTRY main {
+ arg = f32[2,2,2]{0,1,2} parameter(0)
+ ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
+}
+)";
+ ParseAndVerifyModule(hlo_text);
+
+ Literal arg = LiteralUtil::CreateR3WithLayout<float>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ LayoutUtil::MakeLayout({0, 1, 2}));
+ Literal actual = Evaluate({&arg});
+ EXPECT_TRUE(LiteralTestUtil::Equal(arg, actual));
+}
+
INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest,
::testing::ValuesIn(use_bf16_params));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 7f090a52db..b2d12c94b8 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
+#include <cmath>
+
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
@@ -41,7 +43,9 @@ template <typename T>
using is_complex64_t = std::is_same<T, complex64>;
// It's UB to use std::sort with std::less<float>, because of NaNs. Define
-// "safe" less functions which are actually strict weak orders.
+// "safe" less functions which are actually strict weak orders. -NaN and NaN
+// should appear at the beginning and end of the ordering, and -0.0 should
+// appear before 0.0.
template <
typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr>
@@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
return a < b;
}
-template <typename NativeT,
- typename std::enable_if<
- std::is_floating_point<NativeT>::value ||
- std::is_same<NativeT, bfloat16>::value>::type* = nullptr>
+template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (std::isnan(b)) {
- return !std::isnan(a);
- } else {
- return a < b;
+ bool lhs_is_negative = std::signbit(a);
+ bool rhs_is_negative = std::signbit(b);
+ // If the signs are different, we can just compare the signs.
+ if (lhs_is_negative != rhs_is_negative) {
+ return lhs_is_negative && !rhs_is_negative;
+ }
+ bool lhs_nan = std::isnan(a);
+ bool rhs_nan = std::isnan(b);
+ // Exactly one number is nan?
+ if (lhs_nan != rhs_nan) {
+ if (lhs_nan) {
+ return lhs_is_negative;
+ }
+ return !rhs_is_negative;
}
+ return a < b;
}
-template <typename NativeT, typename std::enable_if<std::is_same<
- NativeT, Eigen::half>::value>::type* = nullptr>
+template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value>::type* = nullptr>
bool SafeLess(const NativeT& a, const NativeT& b) {
- if (Eigen::half_impl::isnan(b)) {
- return !Eigen::half_impl::isnan(a);
- } else {
- return a < b;
- }
+ return SafeLess(static_cast<float>(a), static_cast<float>(b));
}
// Templated DfsHloVisitor for use by HloEvaluator.
@@ -78,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) {
// to this rule, notably:
// - HandleCompare and HandleIsFinite: where the resulting literal type is
// always boolean.
+// - HandleImag and HandleReal: where the resulting literal type is always float
+// and the operand is always complex, or real in the case of HandleReal.
// These operations are handled outside of the parent HloEvaluator handlers
// instead of from within TypedVisitor.
//
@@ -249,12 +262,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
TF_ASSIGN_OR_RETURN(Literal result,
parent_->GetEvaluatedLiteralFor(operand).Convert(
convert->shape().element_type()));
-
- if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -265,11 +273,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
parent_->GetEvaluatedLiteralFor(operand).BitcastConvert(
convert->shape().element_type()));
- if (LayoutUtil::LayoutsInShapesEqual(result.shape(), convert->shape())) {
- parent_->evaluated_[convert] = std::move(result);
- } else {
- parent_->evaluated_[convert] = result.Relayout(convert->shape().layout());
- }
+ parent_->evaluated_[convert] = std::move(result);
return Status::OK();
}
@@ -327,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleFloor<ReturnT>(floor);
}
- Status HandleImag(HloInstruction* imag) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag],
- ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) {
- return std::imag(elem_operand);
- }));
- return Status::OK();
- }
-
Status HandleLog(HloInstruction* log) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[log],
ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) {
@@ -682,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
- Status HandleReal(HloInstruction* real) override {
- TF_ASSIGN_OR_RETURN(parent_->evaluated_[real],
- ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) {
- return std::real(elem_operand);
- }));
- return Status::OK();
- }
-
template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr>
Status HandleRemainder(HloInstruction* remainder) {
@@ -1536,47 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
!std::is_same<NativeT, bool>::value>::type* = nullptr>
Status HandleSort(HloInstruction* sort) {
auto keys = sort->operand(0);
- auto rank = ShapeUtil::Rank(keys->shape());
- TF_RET_CHECK(rank > 0 && rank <= 2)
- << "Sort is only supported for R1 and R2 shapes";
TF_RET_CHECK(sort->operand_count() == 1)
<< "Typed visitor does not support key-value sort";
const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
-
- auto sort_r1 = [this](const Literal& keys_literal) {
- VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString();
- const auto& keys_data = keys_literal.data<ReturnT>();
-
- std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const ReturnT& a, const ReturnT& b) {
- return SafeLess<ReturnT>(a, b);
- });
- Literal result_literal(keys_literal.shape());
- result_literal.PopulateR1(absl::Span<const ReturnT>(result_data));
- VLOG(3) << "HandleSort result_literal: " << result_literal.ToString();
- return result_literal;
- };
-
- if (rank == 1) {
- parent_->evaluated_[sort] = std::move(sort_r1(keys_literal));
- } else {
- // For R2 sort, the desired semantics are to sort each matrix row
- // independently.
- Literal result_literal(keys_literal.shape());
- int64 r1_length = keys->shape().dimensions(1);
- for (int64 row = 0; row < keys->shape().dimensions(0); ++row) {
- TF_ASSIGN_OR_RETURN(auto r1_slice,
- keys_literal.Slice({row, 0}, {row + 1, r1_length})
- .Reshape({r1_length}));
- auto r1_result = sort_r1(r1_slice);
- TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length}));
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- r1_result, {0, 0}, {row, 0}, {1, r1_length}));
- }
- parent_->evaluated_[sort] = std::move(result_literal);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
+ int64 rank = ShapeUtil::Rank(keys->shape());
+ if (rank == 0) {
+ // Nothing to sort.
+ parent_->evaluated_[sort] = keys_literal.Clone();
+ return Status::OK();
}
+ Literal result_literal(keys_literal.shape());
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
+ increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from the literal that corresponds to exactly the
+ // row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ std::for_each(limit_indices.begin(), limit_indices.end(),
+ [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto row_to_sort,
+ keys_literal.Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ const auto& row_data = row_to_sort.data<NativeT>();
+
+ std::vector<NativeT> result_data(row_data.begin(), row_data.end());
+ std::sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
+ Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
+ {sort_dim_elements}));
+ sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
+ sorted_row.Reshape(slice_dimensions));
+ std::vector<int64> start_indices(rank, 0);
+ TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
+ sorted_row_reshaped, start_indices, indices, slice_dimensions));
+ return true;
+ }));
+ parent_->evaluated_[sort] = std::move(result_literal);
return Status::OK();
}
@@ -2274,19 +2270,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
// be 1.
int64 update_dim_size =
update_dim == -1 ? 1 : updates_shape.dimensions(update_dim);
- // Clamp the scatter index so that the scatter region fits in the
- // operand. input_scatter_index_clamped[i] =
- // clamp(input_scatter_index[i], 0,
- // operand_shape.dimensions(i) -
- // update_dim_size);
- input_scatter_index_clamped[i] =
- std::min(operand_shape.dimensions(i) - update_dim_size,
- std::max(0LL, input_scatter_index[i]));
+ // If any part of the update region is out-of-bounds, then do not
+ // perform any update on the input.
+ if ((input_scatter_index[i] < 0) ||
+ (input_scatter_index[i] >
+ operand_shape.dimensions(i) - update_dim_size)) {
+ return true;
+ }
}
for (int i = 0, e = input_index.size(); i < e; i++) {
- input_index[i] = input_scatter_index_clamped[i] + input_window_index[i];
- DCHECK_GE(input_index[i], 0);
- DCHECK_LT(input_index[i], operand_shape.dimensions(i));
+ input_index[i] = input_scatter_index[i] + input_window_index[i];
}
auto result_value_literal =
@@ -2350,8 +2343,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return operand_literal.Get<ReturnT>(operand_index);
};
- auto result = LiteralUtil::CreateFromDimensions(
- shape.element_type(), AsInt64Slice(shape.dimensions()));
+ Literal result(shape);
TF_RETURN_IF_ERROR(result.Populate<ReturnT>(func));
parent_->evaluated_[slice] = std::move(result);
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index d52f4e5a61..13a74fd8a1 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -123,8 +123,8 @@ class NodeFilter {
// We arbitrarily set this as the boundary between "large" and "small"
// instructions.
bool IsSmall(const HloInstruction* instr) {
- if (ShapeUtil::IsOpaque(instr->shape()) ||
- ShapeUtil::IsToken(instr->shape())) {
+ if (ShapeUtil::HasPrimitiveType(instr->shape(), OPAQUE) ||
+ ShapeUtil::HasPrimitiveType(instr->shape(), TOKEN)) {
return true;
}
return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096;
@@ -469,9 +469,8 @@ stylesheet=<
string graph_label =
StrCat(label_, "<br/>Computation ", computation_->name());
if (computation_->IsFusionComputation()) {
- StrAppend(&graph_label,
- StrCat(" (in fusion instruction ",
- computation_->FusionInstruction()->name(), ")"));
+ StrAppend(&graph_label, " (in fusion instruction ",
+ computation_->FusionInstruction()->name(), ")");
}
if (profile_ != nullptr) {
auto cycles = profile_->total_cycles_executed(*computation_);
@@ -1111,7 +1110,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
instr->metadata().source_line()));
}
- return StrJoin(lines, "<br/>");
+ return StrJoin(lines, "\n");
}
string HloDotDumper::GetInstructionNodeBackendConfig(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 85fa3ce964..ad58833e4d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
+ instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,
@@ -2909,6 +2910,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
+bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ if (rhs == nullptr) {
+ // Nothing compares less than nullptr.
+ return false;
+ }
+ if (lhs == nullptr) {
+ return true;
+ }
+ auto lhs_module = lhs->GetModule();
+ auto rhs_module = rhs->GetModule();
+ CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
+ (lhs_module != nullptr && rhs_module != nullptr));
+ if (lhs_module != nullptr &&
+ lhs_module->unique_id() != rhs_module->unique_id()) {
+ return lhs_module->unique_id() < rhs_module->unique_id();
+ }
+ return lhs->unique_id() < rhs->unique_id();
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4f6cac1396..d615df0831 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1616,6 +1616,10 @@ class HloInstruction {
InstructionVector operands_;
// The set of control predecessors of this instruction.
+ // Note that the order of the instructions in the vector influences the order
+ // computed in HloComputation::ComputeInstructionPostOrder, which may
+ // influence the result of the compilation by changing the scheduling. We are
+ // not sure if it matters.
std::vector<HloInstruction*> control_predecessors_;
// The users of this instruction. Users are HLOs where this instruction is an
@@ -1689,21 +1693,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo. Exception: null pointer values compare less than non-null.
-//
-// Note that this cannot be used for HLO instructions across multiple modules
-// since the id of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
bool operator()(const HloInstruction* const& lhs,
- const HloInstruction* const& rhs) const {
- if (rhs == nullptr) {
- // Nothing compares less than nullptr.
- return false;
- }
- if (lhs == nullptr) {
- return true;
- }
- return lhs->unique_id() < rhs->unique_id();
- }
+ const HloInstruction* const& rhs) const;
};
template <typename ValueT>
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
index 3a1dd471c6..5bf055f3c0 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc
@@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers(
}
}
+// Makes sure that if a live instruction is within a computation used in control
+// flow operations, we mark live even other related instructions.
+void PropagateLivenessThroughControlFlow(
+ const HloInstruction* instruction,
+ HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+ Workset* workset, CallGraph* call_graph) {
+ const CallGraphNode& call_graph_node =
+ call_graph->GetNode(instruction->parent());
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ HloInstruction* caller = callsite.instruction();
+ if (caller->opcode() == HloOpcode::kWhile) {
+ // If a live instruction is within the %while body or condition
+ // computation, mark the predicate value returned by the condition
+ // computation live as well.
+ MarkLiveAtIndex(caller->while_condition()->root_instruction(), {},
+ live_index_map, worklist, workset);
+ } else if (caller->opcode() == HloOpcode::kConditional) {
+ // If a live instruction is within the true or false branches of a
+ // conditional, we mark the predicate operand live as well.
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist,
+ workset);
+ }
+ }
+ }
+}
+
} // namespace
HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
@@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() {
} else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kWhile &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kWhile) {
PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
&workset);
- } else if (instruction->opcode() == HloOpcode::kParameter &&
- ShapeUtil::IsTuple(instruction->shape())) {
+ } else if (instruction->opcode() == HloOpcode::kParameter) {
PropagateLivenessToParameterCallers(instruction, &live_index_map_,
&worklist, &workset,
call_graph_.get());
@@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() {
MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset);
}
}
+ PropagateLivenessThroughControlFlow(instruction, &live_index_map_,
+ &worklist, &workset, call_graph_.get());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 01b625c29c..e0ae1173c6 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2}));
}
+TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
+TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ InnerWhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ InnerWhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ OuterWhileCondition {
+ cond_param.2 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0
+ constant.5 = s32[] constant(5)
+ ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5)
+ }
+ OuterWhileBody {
+ body_param.2 = (s32[]) parameter(0)
+ get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0
+ constant.6 = s32[] constant(0)
+ tuple.2 = (s32[]) tuple(constant.6)
+ inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition,
+ body=InnerWhileBody
+ constant.7 = s32[] constant(1)
+ add.2 = s32[] add(get-tuple-element.8, constant.7)
+ ROOT rtuple = (s32[]) tuple(add.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=OuterWhileCondition,
+ body=OuterWhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ const HloLivenessAnalysis& liveness = RunLiveness(module.get());
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {}));
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 5e02868eba..9964c6fdd7 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -90,7 +90,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
// A pass which schedules the HLO instructions in a module. The HloModule's
// schedule field is set to the resulting HloSchedule using
// HloModule::set_schedule.
-class HloMemoryScheduler : public HloPassInterface {
+class HloMemoryScheduler : public HloModulePass {
public:
// size_function is the function returning the number of bytes required for a
// LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not
@@ -109,7 +109,7 @@ class HloMemoryScheduler : public HloPassInterface {
// A trivial pass which clears the schedule currently set on the
// HloModule. After this pass runs HloModudle::has_schedule will return false.
-class HloDescheduler : public HloPassInterface {
+class HloDescheduler : public HloModulePass {
public:
HloDescheduler() = default;
~HloDescheduler() override = default;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cfe906d9c5..b3949f3a6d 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names) {
+ bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
- if (uniquify_names) {
+ if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
+
+ // Pick unique IDs for each instruction.
+ for (auto* instruction : computation->instructions()) {
+ instruction->SetUniqueId(NewUniqueInstructionId());
+ }
+ // Set unique id to this computation.
+ CHECK_NE(computation->root_instruction()->unique_id(), -1)
+ << "Root has no valid id: " << computation->ToString();
+ computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
- // for computations and instructions created later.
+ // for computations and instructions created later. Also, set the
+ // next_unique_id_ to the one greater than the max unique id of any
+ // instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
+ next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
+ }
+ if (next_unique_id_ < computation->unique_id() + 1) {
+ next_unique_id_ = computation->unique_id() + 1;
}
}
- // Pick unique IDs for each instruction.
- for (auto* instruction : computation->instructions()) {
- instruction->SetUniqueId(NewUniqueInstructionId());
- }
- // Set unique id to this computation.
- CHECK_NE(computation->root_instruction()->unique_id(), -1)
- << "Root has no valid id: " << computation->ToString();
- computation->SetUniqueId(computation->root_instruction()->unique_id());
-
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
- /*uniquify_names=*/true);
+ /*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ VLOG(2) << "CreateFromProto()";
+ XLA_VLOG_LINES(2, proto.DebugString());
+
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
- /*uniquify_names=*/false);
+ /*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
- // Because we didn't uniquify the names, double-check that the instruction and
- // computation names are unique from the proto.
+ // Because we didn't uniquify the names or the ids, double-check that the
+ // instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
+ tensorflow::gtl::FlatSet<int> computation_ids;
+ tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
+
+ TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
+ << "Computation id is not unique: " << computation->unique_id();
+ computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
+
+ TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
+ << "Instruction id is not unique: " << instruction->unique_id();
+ instruction_ids.insert(instruction->unique_id());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 26fd1b2438..735804e827 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -63,6 +63,7 @@ class HloModule {
// tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config);
+ virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry
// computation. Returns a pointer to the newly added computation.
@@ -87,6 +88,7 @@ class HloModule {
const std::unordered_map<HloComputation*, HloComputation*>& replacements);
const string& name() const { return name_; }
+ void set_name(string name) { name_ = std::move(name); }
// Returns a deep copy of this module including all computations.
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
@@ -253,9 +255,9 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
- bool uniquify_names);
+ bool uniquify_identifiers);
- const string name_;
+ string name_;
HloModuleConfig config_;
HloComputation* entry_computation_ = nullptr;
std::vector<std::unique_ptr<HloComputation>> computations_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc
index 98d20315e3..31d26cc51e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc
@@ -36,23 +36,6 @@ namespace xla {
namespace {
-bool HasSendRecv(HloComputation* computation) {
- for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kSendDone ||
- instruction->opcode() == HloOpcode::kRecv ||
- instruction->opcode() == HloOpcode::kRecvDone) {
- return true;
- }
- for (auto* sub_computation : instruction->called_computations()) {
- if (HasSendRecv(sub_computation)) {
- return true;
- }
- }
- }
- return false;
-}
-
StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
bool changed = false;
for (auto* computation : module->computations()) {
@@ -67,10 +50,9 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) {
auto* while_body_root = while_body_comp->root_instruction();
if (!ShapeUtil::IsTuple(xla_while->shape()) ||
- while_body_root->opcode() != HloOpcode::kTuple ||
- HasSendRecv(while_body_comp)) {
+ while_body_root->opcode() != HloOpcode::kTuple) {
// Only run DCE on tuple-shaped while loops where body root is Tuple,
- // with no send/recv instructions.
+ // with no I/O instructions.
VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString();
continue;
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h
index 12ca2340a6..d472211d2a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce.h
+++ b/tensorflow/compiler/xla/service/hlo_module_dce.h
@@ -28,7 +28,7 @@ namespace xla {
// Sweeps through live instructions which cross computation boundaries (kWhile),
// and removes code at dead shape indices.
//
-class HloModuleDCE : public HloPassInterface {
+class HloModuleDCE : public HloModulePass {
public:
~HloModuleDCE() override {}
absl::string_view name() const override { return "hlo-module-dce"; }
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 363862e490..bf66cc6bc3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -367,5 +367,77 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
"while.2", 1));
}
+// Tests that a while whose body has outfeed operations is not DCE-ed.
+TEST_F(HloModuleDceTest, WhileWithOutfeed) {
+ auto module = ParseHloString(R"(
+ HloModule OutfeedLoop
+ WhileBody {
+ body_param = (s32[]) parameter(0)
+ token = token[] after-all()
+ constant.2 = s32[] constant(2)
+ outfeed_tuple = (s32[]) outfeed(constant.2, token)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[]) tuple(add)
+ }
+ WhileCondition {
+ cond_param = (s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[]) tuple(constant.3)
+ while = (s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT rtuple = () tuple()
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
+// Tests that if a loop variable is not referenced outside of a kWhile, the loop
+// variable changes are not elided within the loop body, if the condition
+// computation uses them.
+TEST_F(HloModuleDceTest, WhileWithOnlyLoopVariableBumping) {
+ auto module = ParseHloString(R"(
+ HloModule InfiniteLoop
+ WhileBody {
+ body_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0
+ get-tuple-element.2 = s32[] get-tuple-element(body_param), index=1
+ constant.1 = s32[] constant(1)
+ add = s32[] add(get-tuple-element.1, constant.1)
+ ROOT tuple = (s32[], s32[]) tuple(add, get-tuple-element.2)
+ }
+ WhileCondition {
+ cond_param = (s32[], s32[]) parameter(0)
+ get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0
+ constant.2 = s32[] constant(10)
+ ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2)
+ }
+ ENTRY SimpleLoop {
+ p0 = (s32[]) parameter(0)
+ get-tuple-element.5 = s32[] get-tuple-element(p0), index=0
+ constant.3 = s32[] constant(0)
+ tuple.1 = (s32[], s32[]) tuple(constant.3, get-tuple-element.5)
+ while = (s32[], s32[]) while(tuple.1), condition=WhileCondition,
+ body=WhileBody
+ ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=1
+ })")
+ .ValueOrDie();
+
+ HloModuleDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(),
+ "while", 0));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.cc b/tensorflow/compiler/xla/service/hlo_module_group.cc
new file mode 100644
index 0000000000..f9b56ef464
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.cc
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+namespace xla {
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ std::unique_ptr<HloModule> module)
+ : name_(name) {
+ push_back(std::move(module));
+}
+
+HloModuleGroup::HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules)
+ : name_(name) {
+ for (auto& module : modules) {
+ push_back(std::move(module));
+ }
+}
+
+std::vector<std::unique_ptr<HloModule>> HloModuleGroup::ConsumeModules() {
+ std::vector<std::unique_ptr<HloModule>> ret_modules = std::move(modules_);
+
+ // Clear everything so the object state is in a known (empty) state.
+ modules_.clear();
+ module_ptrs_.clear();
+ return ret_modules;
+}
+
+string HloModuleGroup::ToString() const {
+ std::ostringstream s;
+ s << "HloModuleGroup " << name() << "\n\n";
+ for (const HloModule* module : modules()) {
+ s << module->ToString() << "\n";
+ }
+ return s.str();
+}
+
+HloModuleGroupProto HloModuleGroup::ToProto() const {
+ HloModuleGroupProto proto;
+ proto.set_name(name());
+ for (const HloModule* module : modules()) {
+ *proto.add_hlo_modules() = module->ToProto();
+ }
+ return proto;
+}
+
+/* static */ StatusOr<HloModuleGroup> HloModuleGroup::CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs) {
+ TF_RET_CHECK(!proto.name().empty()) << "Module group name cannot be empty";
+ TF_RET_CHECK(proto.hlo_modules_size() > 0)
+ << "Module group must have at least one HLO module";
+ TF_RET_CHECK(proto.hlo_modules_size() == module_configs.size());
+
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (int i = 0; i < proto.hlo_modules_size(); ++i) {
+ const HloModuleProto& module_proto = proto.hlo_modules(i);
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(module_proto, module_configs[i]));
+ modules.push_back(std::move(module));
+ }
+
+ return HloModuleGroup(proto.name(), absl::MakeSpan(modules));
+}
+
+void HloModuleGroup::push_back(std::unique_ptr<HloModule> module) {
+ modules_.push_back(std::move(module));
+ module_ptrs_.push_back(modules_.back().get());
+}
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group) {
+ out << group.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_group.h b/tensorflow/compiler/xla/service/hlo_module_group.h
new file mode 100644
index 0000000000..7338be8b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// An abstraction representing a ordered set of HLO module built to run
+// concurrently across different devices.
+class HloModuleGroup {
+ public:
+ // Construct an empty module group.
+ explicit HloModuleGroup(absl::string_view name) : name_(name) {}
+
+ // Construct a module group containing a single module.
+ HloModuleGroup(absl::string_view name, std::unique_ptr<HloModule> module);
+
+ // Construct a module group containing any number of modules.
+ HloModuleGroup(absl::string_view name,
+ absl::Span<std::unique_ptr<HloModule>> modules);
+
+ // Returns the modules contained in the group.
+ const std::vector<HloModule*>& modules() const { return module_ptrs_; }
+
+ // Returns a module at a particular index.
+ HloModule& module(int index) const { return *module_ptrs_.at(index); }
+
+ // Add a module to the back of vector of modules in the group.
+ void push_back(std::unique_ptr<HloModule> module);
+
+ // Moves all modules from the group into the returned vector. After this
+ // method runs, the module group will be empty.
+ std::vector<std::unique_ptr<HloModule>> ConsumeModules();
+
+ string name() const { return name_; }
+ string ToString() const;
+
+ // Serialize the module group to/from a proto.
+ HloModuleGroupProto ToProto() const;
+ static StatusOr<HloModuleGroup> CreateFromProto(
+ const HloModuleGroupProto& proto,
+ absl::Span<const HloModuleConfig> module_configs);
+
+ private:
+ string name_;
+
+ // Vector of modules as std::unique_ptrs.
+ std::vector<std::unique_ptr<HloModule>> modules_;
+
+ // Vector of modules as normal pointers. This vector is kept in sync with
+ // modules_ as modules are added to the group with push_back.
+ std::vector<HloModule*> module_ptrs_;
+};
+
+std::ostream& operator<<(std::ostream& out, const HloModuleGroup& group);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 9c01862a4b..83352ef91b 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -392,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
if (!ContainsKey(companion_set_index_, instruction1) &&
!ContainsKey(companion_set_index_, instruction2)) {
companion_sets_.push_back(
- absl::make_unique<std::unordered_set<HloInstruction*>>());
+ absl::make_unique<std::vector<HloInstruction*>>());
auto companion_set = companion_sets_.back().get();
- companion_set->insert(instruction1);
- companion_set->insert(instruction2);
+ companion_set->push_back(instruction1);
+ companion_set->push_back(instruction2);
companion_set_index_[instruction1] = companion_sets_.size() - 1;
companion_set_index_[instruction2] = companion_sets_.size() - 1;
} else if (!ContainsKey(companion_set_index_, instruction1)) {
- companion_sets_[companion_set_index_[instruction2]]->insert(instruction1);
+ companion_sets_[companion_set_index_[instruction2]]->push_back(
+ instruction1);
companion_set_index_[instruction1] = companion_set_index_[instruction2];
} else if (!ContainsKey(companion_set_index_, instruction2)) {
- companion_sets_[companion_set_index_[instruction1]]->insert(instruction2);
+ companion_sets_[companion_set_index_[instruction1]]->push_back(
+ instruction2);
companion_set_index_[instruction2] = companion_set_index_[instruction1];
} else if (companion_set_index_[instruction1] !=
companion_set_index_[instruction2]) {
- companion_sets_[companion_set_index_[instruction1]]->insert(
- Companions(instruction2).begin(), Companions(instruction2).end());
+ // At any point while building the companion sets, each instruction belongs
+ // to at most 1 companion set, so the union of two companion sets is
+ // concatenating two disjoint sets.
+ absl::c_copy(Companions(instruction2),
+ std::back_inserter(
+ *companion_sets_[companion_set_index_[instruction1]]));
int64 index_to_remove = companion_set_index_[instruction2];
for (HloInstruction* hlo : Companions(instruction2)) {
companion_set_index_[hlo] = companion_set_index_[instruction1];
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 768b0c7eb3..278d94cdd3 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -169,14 +169,14 @@ class HloModuleGroupMetadata {
// Returns the companion instructions for the given instruction.
//
// Precondition: IsCompanionWhile(instruction) is true.
- const std::unordered_set<HloInstruction*>& Companions(
+ const std::vector<HloInstruction*>& Companions(
const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
// Returns the companion set at the given index.
- const std::unordered_set<HloInstruction*>& companion_set(int64 index) const {
+ const std::vector<HloInstruction*>& companion_set(int64 index) const {
CHECK_LT(index, companion_sets_.size());
return *companion_sets_[index];
}
@@ -187,7 +187,7 @@ class HloModuleGroupMetadata {
}
// Returns the list of all companion sets in the HLO module group.
- const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>&
+ const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets() const {
return companion_sets_;
}
@@ -247,8 +247,7 @@ class HloModuleGroupMetadata {
void DumpCollectedStats() const;
// List of all companion instructions sets in the module.
- std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>
- companion_sets_;
+ std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
new file mode 100644
index 0000000000..b7b12cb72b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc
@@ -0,0 +1,206 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
+
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+
+class HloModuleGroupTest : public HloTestBase {
+ protected:
+ HloModuleGroupTest() = default;
+};
+
+TEST_F(HloModuleGroupTest, SingleModule) {
+ const string text = R"(
+HloModule simple_module
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+ HloModuleGroup group(TestName(), std::move(module));
+
+ EXPECT_EQ(group.modules().size(), 1);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 1);
+ EXPECT_THAT(
+ group_copy.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+
+ std::vector<std::unique_ptr<HloModule>> modules = group.ConsumeModules();
+ EXPECT_EQ(modules.size(), 1);
+ EXPECT_EQ(group.modules().size(), 0);
+}
+
+TEST_F(HloModuleGroupTest, MultipleModules) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ std::vector<std::unique_ptr<HloModule>> modules;
+ modules.push_back(std::move(module_0));
+ modules.push_back(std::move(module_1));
+ HloModuleGroup group(TestName(), absl::MakeSpan(modules));
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup group_copy,
+ HloModuleGroup::CreateFromProto(
+ group.ToProto(), {group.module(0).config(),
+ group.module(1).config()}));
+ EXPECT_EQ(group_copy.modules().size(), 2);
+}
+
+TEST_F(HloModuleGroupTest, BuildModuleGroupByPushBack) {
+ const string text_0 = R"(
+HloModule module0
+
+ENTRY %entry (x: f32[], y: f32[]) -> f32[] {
+ %x = f32[] parameter(0)
+ %y = f32[] parameter(1)
+ ROOT %add = f32[] add(%x, %y)
+}
+)";
+ const string text_1 = R"(
+HloModule module1
+
+ENTRY %entry (a: f32[]) -> f32[] {
+ ROOT %a = f32[] parameter(0)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_0,
+ ParseHloString(text_0));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module_1,
+ ParseHloString(text_1));
+ HloModuleGroup group(TestName());
+ group.push_back(std::move(module_0));
+ group.push_back(std::move(module_1));
+
+ EXPECT_EQ(group.modules().size(), 2);
+ EXPECT_THAT(
+ group.module(0).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Add()));
+ EXPECT_THAT(group.module(1).entry_computation()->instructions(),
+ ::testing::ElementsAre(op::Parameter()));
+}
+
+// Tests that the order of companion instructions in the companion set doesn't
+// change across runs.
+TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) {
+ // A simple while loop template for core i sending to core i+1.
+ constexpr char text[] = R"(
+HloModule module_%d
+
+while_cond {
+ ROOT p = pred[] constant(true)
+}
+
+while_body {
+ param = s32[] parameter(0)
+ token.s = token[] after-all()
+ token.r = token[] after-all()
+ send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d
+ send-done = token[] send-done(send), channel_id=%d
+ recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d
+ ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d
+}
+
+ENTRY entry {
+ while_init = s32[] constant(1)
+ ROOT while = s32[] while(while_init), condition=while_cond, body=while_body
+}
+)";
+
+ // Try creating the module and the metadata kTrialCount times and check the
+ // companion instructions remain in the same order.
+ const int64 kTrialCount = 5;
+ const int64 kDeviceCount = 10;
+ std::vector<int64> companion_order;
+
+ for (int64 t = 0; t < kTrialCount; ++t) {
+ HloModuleGroup group(TestName());
+ for (int64 i = 0; i < kDeviceCount; ++i) {
+ const int64 send_channel = i;
+ const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1;
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(absl::StrFormat(text, i, send_channel, send_channel,
+ recv_channel, recv_channel)));
+ group.push_back(std::move(module));
+ }
+ ASSERT_EQ(group.modules().size(), kDeviceCount);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto metadata,
+ HloModuleGroupMetadata::Build(group.modules()));
+ ASSERT_EQ(metadata->companion_sets().size(), 1);
+
+ std::vector<int64> module_ids;
+ for (HloInstruction* companion : *metadata->companion_sets()[0]) {
+ module_ids.push_back(metadata->GetModuleId(companion->GetModule()));
+ }
+
+ if (t == 0) {
+ companion_order = module_ids;
+ } else {
+ EXPECT_TRUE(absl::c_equal(companion_order, module_ids));
+ }
+ }
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
index 400bd4d947..39f38b417a 100644
--- a/tensorflow/compiler/xla/service/hlo_module_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -20,12 +20,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/test.h"
@@ -253,6 +253,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
+ // Verify that serializing then deserializing an HLO proto preserves the
+ // unique IDs of the instruction and module.
+ const string text =
+ R"(HloModule ReduceR3ToR2_module
+
+add_F32.v3 {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY ReduceR3ToR2.v3 {
+ input = f32[8,16,256]{2,1,0} parameter(0)
+ constant = f32[] constant(0)
+ ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(text));
+
+ // Perform various transformations on the graph:
+ //
+ // * clone the reduction function
+ // * replace use of reduction function with the clone.
+ // * add a random instruction to the entry computation.
+ //
+ // This will create instruction and computation IDs which are interesting:
+ // not consecutive and not densely packed.
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+ HloComputation* reduction = root->to_apply();
+ HloComputation* reduction_clone =
+ module->AddEmbeddedComputation(reduction->Clone());
+ root->set_to_apply(reduction_clone);
+ TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
+ HloInstruction* negate = entry->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
+ entry->set_root_instruction(negate);
+
+ // Schedule the transformed module, this verifies that the serialized schedule
+ // is robust against non-consecutive IDs as well (b/114712358).
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape());
+ };
+ HloMemoryScheduler scheduler(size_fn);
+ TF_ASSERT_OK(scheduler.Run(module.get()).status());
+ ASSERT_TRUE(module->has_schedule());
+
+ // Serialize and deserialize and verify that the instruction and computations
+ // unique ids are the same.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module_copy,
+ HloModule::CreateFromProto(module->ToProto(), module->config()));
+
+ // The module IDs should *not* be the same because module ids must be globally
+ // unique.
+ EXPECT_NE(module->unique_id(), module_copy->unique_id());
+
+ // Verify that the computations and instructions all have the same unique id.
+ auto computation_copy_it = module_copy->computations().begin();
+ for (const HloComputation* computation_orig : module->computations()) {
+ const HloComputation* computation_copy = *computation_copy_it++;
+ EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original computation %s != ID of deserialized "
+ "computation %s: %d != %d",
+ computation_orig->name(), computation_copy->name(),
+ computation_orig->unique_id(), computation_copy->unique_id());
+
+ auto instruction_copy_it = computation_copy->instructions().begin();
+ for (const HloInstruction* instruction_orig :
+ computation_orig->instructions()) {
+ const HloInstruction* instruction_copy = *instruction_copy_it++;
+ EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
+ << absl::StrFormat(
+ "ID of original instruction %s != ID of deserialized "
+ "instruction %s: %d != %d",
+ instruction_orig->name(), instruction_copy->name(),
+ instruction_orig->unique_id(), instruction_copy->unique_id());
+ }
+ }
+
+ // Verify that the next unique ID which the module would have handed out is
+ // greater than the unique id of any instruction.
+ int next_id = module_copy->NewUniqueInstructionId();
+ for (const HloComputation* computation : module_copy->computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ EXPECT_GT(next_id, instruction->unique_id());
+ }
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 11caa89c54..37197b273b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -64,14 +64,11 @@ class HloParser {
public:
using LocTy = HloLexer::LocTy;
- explicit HloParser(absl::string_view str, const HloModuleConfig& config)
- : lexer_(str), config_(config) {}
+ explicit HloParser(absl::string_view str) : lexer_(str) {}
- // Runs the parser. Returns false if an error occurred.
- bool Run();
-
- // Returns the parsed HloModule.
- std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
+ // Runs the parser and constructs the resulting HLO in the given (empty)
+ // HloModule. Returns false if an error occurred.
+ bool Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -98,8 +95,8 @@ class HloParser {
const string& name, const optional<Shape>& shape = nullopt);
// ParseXXX returns false if an error occurred.
- bool ParseHloModule();
- bool ParseComputations();
+ bool ParseHloModule(HloModule* module);
+ bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
@@ -293,9 +290,7 @@ class HloParser {
computation_pool_;
HloLexer lexer_;
- std::unique_ptr<HloModule> module_;
std::vector<std::unique_ptr<HloComputation>> computations_;
- const HloModuleConfig config_;
std::vector<string> error_;
// Function that gets invoked when we try to resolve an instruction
@@ -349,9 +344,9 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run() {
+bool HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule();
+ return ParseHloModule(module);
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -366,7 +361,7 @@ std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
}
// ::= 'HloModule' name computations
-bool HloParser::ParseHloModule() {
+bool HloParser::ParseHloModule(HloModule* module) {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
@@ -385,22 +380,20 @@ bool HloParser::ParseHloModule() {
return false;
}
- module_ = absl::make_unique<HloModule>(name, config_);
-
- if (!ParseComputations()) {
+ module->set_name(name);
+ if (!ParseComputations(module)) {
return false;
}
if (is_scheduled.has_value() && *is_scheduled) {
- TF_CHECK_OK(
- module_->set_schedule(ScheduleFromInstructionOrder(module_.get())));
+ TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module)));
}
return true;
}
// computations ::= (computation)+
-bool HloParser::ParseComputations() {
+bool HloParser::ParseComputations(HloModule* module) {
HloComputation* entry_computation = nullptr;
do {
if (!ParseComputation(&entry_computation)) {
@@ -416,21 +409,20 @@ bool HloParser::ParseComputations() {
if ((entry_computation != nullptr &&
computations_[i].get() != entry_computation) ||
(entry_computation == nullptr && i != computations_.size() - 1)) {
- module_->AddEmbeddedComputation(std::move(computations_[i]));
+ module->AddEmbeddedComputation(std::move(computations_[i]));
continue;
}
- auto computation =
- module_->AddEntryComputation(std::move(computations_[i]));
+ auto computation = module->AddEntryComputation(std::move(computations_[i]));
// The parameters and result layouts were set to default layout. Here we
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_parameter_layout(p)
->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ TF_CHECK_OK(module->mutable_entry_computation_layout()
->mutable_result_layout()
->CopyLayoutFromShape(result_shape));
}
@@ -3247,53 +3239,62 @@ Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
- HloParser parser(str, config);
- if (!parser.Run()) {
+ auto module = absl::make_unique<HloModule>(/*name=*/"", config);
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
return InvalidArgument("Syntax error:\n%s", parser.GetError());
}
- return parser.ConsumeHloModule();
+ return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
- HloModuleConfig config;
- return ParseHloString(str, config);
+ auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
+ HloParser parser(str);
+ if (!parser.Run(module.get())) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return std::move(module);
+}
+
+Status ParseHloString(absl::string_view str, HloModule* module) {
+ TF_RET_CHECK(module->computation_count() == 0);
+ HloParser parser(str);
+ if (!parser.Run(module)) {
+ return InvalidArgument("Syntax error:\n%s", parser.GetError());
+ }
+ return Status::OK();
}
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
auto builder = absl::make_unique<HloComputation::Builder>(string(name));
string root_name;
TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
std::unique_ptr<HloComputation> computation = builder->Build();
- auto module = absl::make_unique<HloModule>(string(name), config);
+ auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
module->AddEntryComputation(std::move(computation));
return std::move(module);
}
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseShardingOnly();
}
StatusOr<Window> ParseWindow(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseWindowOnly();
}
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParseConvolutionDimensionNumbersOnly();
}
StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) {
- HloModuleConfig config;
- HloParser parser(str, config);
+ HloParser parser(str);
return parser.ParsePaddingConfigOnly();
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 1882a184da..3696035514 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -30,18 +30,23 @@ namespace xla {
// For details about the syntax accepted by this parser, see
// g3doc/hlo_parser.md.
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with the given config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with the given config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config);
+// Given a string in the HloModule::ToString() format, parses the string and
+// builds the HloModule in place at the given module pointer. 'module' must
+// point to an empty module (no computations).
+Status ParseHloString(absl::string_view str, HloModule* module);
+
// Parses the text for a single HLO operation into an HLO module with a function
// that runs that operation (with the same parameters) as its entry computation.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
-// The api of the hlo parser. Given a string in the HloModule::ToString()
-// format, parses the string and creates a HloModule with default config.
+// Given a string in the HloModule::ToString() format, parses the string and
+// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h
index f1ad0f9b01..fdaac34386 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_interface.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -25,15 +26,45 @@ limitations under the License.
namespace xla {
// Base class for HLO passes. These are used with the HloPassPipeline to
-// organize a sequence of passes.
+// organize a sequence of passes. An HLO pass should not extend this class
+// directly; it should extend HloModulePass or HloModuleGroupPass.
class HloPassInterface {
public:
virtual ~HloPassInterface() = default;
virtual absl::string_view name() const = 0;
- // Run the pass on the given HLO module. Return whether it modified the
+ // Run the pass on the given HLO module. Returns whether it modified the
// module.
virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ // Run the pass on the given HLO module group. Returns whether it modified the
+ // module group. Ideally, the module group variant would be named "Run" as
+ // well, but C++ does not handle overloaded virtual methods well.
+ virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0;
+};
+
+// Base class for passes which are module-scoped.
+class HloModulePass : public HloPassInterface {
+ public:
+ // Runs the pass on a module group by iterating through each module in the
+ // group.
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ TF_ASSIGN_OR_RETURN(bool module_changed, Run(module));
+ changed |= module_changed;
+ }
+ return changed;
+ };
+};
+
+// Base class for passes which are module-group scoped. These passes cannot run
+// on an HLO module.
+class HloModuleGroupPass : public HloPassInterface {
+ public:
+ StatusOr<bool> Run(HloModule* module) override {
+ return InternalError("Module group pass cannot be run on a module");
+ }
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 6e4ed0de62..8c2f928ca1 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <functional>
-#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -29,108 +28,128 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
-namespace {
-using absl::StrAppend;
-using absl::StrCat;
-
-void DumpModuleGraph(const HloModule& module, const string& message) {
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
- VLOG(3) << "HLO " << message << ":";
- XLA_VLOG_LINES(3, module.ToString());
+template <typename HloT>
+Status HloPassPipeline::RunInvariantCheckers(
+ HloT* hlo, absl::string_view after_pass_name) {
+ for (auto& invariant_checker : invariant_checkers_) {
+ VLOG(1) << " Invariant checker " << invariant_checker->name();
+ StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo);
+ VLOG(1) << " Invariant checker done " << invariant_checker->name();
+ if (!changed_status.ok()) {
+ VLOG(2) << "Failed invariant check:";
+ XLA_VLOG_LINES(2, hlo->ToString());
+ return Status(changed_status.status().code(),
+ absl::StrCat(changed_status.status().error_message(),
+ "\n\nFailed after ", after_pass_name));
+ }
+ TF_RET_CHECK(!changed_status.ValueOrDie())
+ << "invariant checkers must not change the graph";
+ }
+ return Status::OK();
}
-void DumpModuleProto(const HloModule& module, const string& dump_to,
- const string& pipeline_name, const string& pass_name) {
- static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
- static auto* const module_id_to_pass_number =
- new tensorflow::gtl::FlatMap<int64, int64>();
-
- tensorflow::mutex_lock lock(mu);
- const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+template <typename HloT>
+StatusOr<bool> HloPassPipeline::RunPassesInternal(
+ HloT* hlo, absl::Span<HloPassInterface* const> passes) {
+ string last_pass_name = "pipeline-start";
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name));
+ bool changed = false;
+ for (HloPassInterface* pass : passes) {
+ VLOG(1) << " HLO pass " << pass->name();
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/pass->name());
+ TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo));
+ changed |= pass_changed;
+ TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name()));
+ last_pass_name = string(pass->name());
+ }
+ MaybeDumpHlo(*hlo,
+ /*after_pass_name=*/last_pass_name,
+ /*before_pass_name=*/"pipeline-end");
+ return changed;
+}
- const string mod_name = SanitizeFileName(
- absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
- pass_number, pipeline_name, pass_name));
+std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
+ const DebugOptions& debug_options) {
+ auto repeated_field = debug_options.xla_disable_hlo_passes();
+ tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
+ repeated_field.end());
+ if (!disabled_pass_names.empty()) {
+ VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
+ << absl::StrJoin(disabled_pass_names, ", ");
+ }
- TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module),
- dump_to, mod_name));
+ std::vector<HloPassInterface*> enabled_passes;
+ for (auto& pass : passes_) {
+ if (disabled_pass_names.count(string(pass->name())) == 0) {
+ enabled_passes.push_back(pass.get());
+ }
+ }
+ return enabled_passes;
}
-} // namespace
-StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
- run_called_ = true;
+void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ const string& proto_dump_path =
+ module.config().debug_options().xla_dump_per_pass_hlo_proto_to();
+ if (!proto_dump_path.empty()) {
+ static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
+ static auto* const module_id_to_pass_number =
+ new tensorflow::gtl::FlatMap<int64, int64>();
+
+ tensorflow::mutex_lock lock(mu);
+ const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
+
+ const string filename = SanitizeFileName(
+ absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(),
+ pass_number, name(), after_pass_name));
+
+ TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(
+ MakeHloProto(module), proto_dump_path, filename));
+ }
- VLOG(1) << "Running HLO pass pipeline " << name();
+ const string message =
+ StrCat("after ", after_pass_name, ", before ", before_pass_name);
+ hlo_graph_dumper::MaybeDumpHloModule(module, message);
+ VLOG(3) << "HLO " << message << ":";
+ XLA_VLOG_LINES(3, module.ToString());
+}
- auto repeated_field =
- module->config().debug_options().xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
- repeated_field.end());
- if (!disabled_passes.empty()) {
- VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << absl::StrJoin(disabled_passes, ", ");
+void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name) {
+ for (const HloModule* module : module_group.modules()) {
+ MaybeDumpHlo(*module, after_pass_name, before_pass_name);
}
+}
- auto run_invariant_checkers = [this,
- module](const string& message) -> Status {
- for (auto& invariant_checker : invariant_checkers_) {
- VLOG(1) << " Invariant checker " << invariant_checker->name();
- StatusOr<bool> changed_status = invariant_checker->Run(module);
- VLOG(1) << " Invariant checker done " << invariant_checker->name();
- if (!changed_status.ok()) {
- VLOG(2) << "Module failed invariant check:";
- XLA_VLOG_LINES(2, module->ToString());
- return Status(changed_status.status().code(),
- StrCat(changed_status.status().error_message(),
- "\n\nFailed ", message));
- }
- TF_RET_CHECK(!changed_status.ValueOrDie())
- << "invariant checkers must not change the graph";
- }
- return Status::OK();
- };
+StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ run_called_ = true;
- string prefix = StrCat(name(), ": pipeline start");
- bool changed = false;
- string message;
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("before running pipeline: ", name())));
- const string xla_dump_per_pass_hlo_proto_to =
- module->config().debug_options().xla_dump_per_pass_hlo_proto_to();
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- "pipeline_start");
- }
+ VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": "
+ << name();
- for (auto& pass : passes_) {
- if (disabled_passes.count(string(pass->name())) > 0) {
- VLOG(1) << " Skipping HLO pass " << pass->name()
- << ", disabled by --xla_disable_hlo_passes";
- continue;
- }
+ return RunPassesInternal(module,
+ GetEnabledPasses(module->config().debug_options()));
+}
- VLOG(1) << " HLO pass " << pass->name();
+StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) {
+ run_called_ = true;
- // Emit label containing: "after foo-pass, before bar-pass".
- message.clear();
- StrAppend(&message, prefix, ", before ", pass->name());
- DumpModuleGraph(*module, message);
-
- TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
- TF_RETURN_IF_ERROR(
- run_invariant_checkers(StrCat("after running pass: ", pass->name())));
- if (!xla_dump_per_pass_hlo_proto_to.empty()) {
- DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()),
- string(pass->name()));
- }
+ VLOG(1) << "Running HLO pass pipeline on module group "
+ << module_group->name() << ": " << name();
- changed |= changed_this_pass;
- prefix.clear();
- StrAppend(&prefix, name(), ": after ", pass->name());
+ if (module_group->modules().empty()) {
+ VLOG(1) << "Module group is empty. Nothing to do.";
+ return false;
}
- DumpModuleGraph(*module, prefix + ", pipeline end");
- return changed;
+
+ return RunPassesInternal(
+ module_group,
+ GetEnabledPasses(module_group->module(0).config().debug_options()));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
index 1d41a4dac1..09e7033ea4 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/memory/memory.h"
+#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
- // Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override;
private:
+ // Returns the set of passes which are enabled. DebugOptions can selectively
+ // disable passes via --xla_disable_hlo_passes flag.
+ std::vector<HloPassInterface*> GetEnabledPasses(
+ const DebugOptions& debug_options);
+
+ // Maybe dumps the given module or module group depending on flag values
+ // contained in DebugOptions of module config.
+ void MaybeDumpHlo(const HloModuleGroup& module_group,
+ absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+ void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name,
+ absl::string_view before_pass_name);
+
+ // Runs the invariant checker on the given HLO. HloT can be either HloModule
+ // or HloModuleGroup.
+ template <typename HloT>
+ Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name);
+
+ // Helper which runs the given pass on the given HLO. HloT can be either
+ // HloModule or HloModuleGroup.
+ template <typename HloT>
+ StatusOr<bool> RunPassesInternal(HloT* hlo,
+ absl::Span<HloPassInterface* const> passes);
+
+ // Helpers which run the given passes on the given HLO construct. These
+ // helpers enable templating of the core of the pipeline logic by providing
+ // HloModule and HloModuleGroup specific methods with the same name.
+ static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) {
+ return pass->Run(module);
+ }
+ static StatusOr<bool> RunHelper(HloPassInterface* pass,
+ HloModuleGroup* module_group) {
+ return pass->RunOnModuleGroup(module_group);
+ }
+
const string name_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
new file mode 100644
index 0000000000..ee8cb12b23
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class HloPassPipelineTest : public HloVerifiedTestBase {
+ protected:
+ StatusOr<HloModuleGroup> ParseModuleGroup(
+ absl::Span<const string> hlo_strings) {
+ HloModuleGroup group(TestName());
+ for (const string& hlo_string : hlo_strings) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ group.push_back(std::move(module));
+ }
+ return std::move(group);
+ }
+};
+
+// A module pass which renames instructions named 'foo' to 'bar'.
+class FooToBarModulePass : public HloModulePass {
+ absl::string_view name() const override { return "foo2bar"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "foo") {
+ instruction->SetAndSanitizeName("bar");
+ changed = true;
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// A module group pass which renames instructions named 'baz' to 'qux'.
+class BazToQuxModuleGroupPass : public HloModuleGroupPass {
+ absl::string_view name() const override { return "baz2qux"; }
+
+ StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override {
+ bool changed = false;
+ for (HloModule* module : module_group->modules()) {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "baz") {
+ instruction->SetAndSanitizeName("qux");
+ changed = true;
+ }
+ }
+ }
+ }
+ return changed;
+ }
+};
+
+// An invariant checker pass which returns an error if there exists an
+// instruction named 'bar'.
+class BarBlowerUpper : public HloModulePass {
+ absl::string_view name() const override { return "bar-blower-upper"; }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->name() == "bar") {
+ return InternalError("Module has instruction named bar");
+ }
+ }
+ }
+ return false;
+ }
+};
+
+TEST_F(HloPassPipelineTest, ModulePassChanged) {
+ // Test an HLO module pass which changes a module.
+ const string module_str = R"(
+HloModule ModulePassChanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->name(), "foo");
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_TRUE(changed);
+ EXPECT_EQ(root->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
+ // Test an HLO module pass which does not change a module.
+ const string module_str = R"(
+HloModule ModulePassUnchanged
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT blahblah = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<FooToBarModulePass>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+}
+
+TEST_F(HloPassPipelineTest, MixedPipeline) {
+ // Test a pipeline with both a module pass and a module group pass.
+ const string module_0_str = R"(
+HloModule MixedPipeline.1
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT baz = f32[] multiply(a, b)
+}
+)";
+ const string module_1_str = R"(
+HloModule MixedPipeline.0
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
+ ParseModuleGroup({module_0_str, module_1_str}));
+
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ HloInstruction* root0 =
+ module_group.module(0).entry_computation()->root_instruction();
+ HloInstruction* root1 =
+ module_group.module(1).entry_computation()->root_instruction();
+ EXPECT_EQ(root0->name(), "baz");
+ EXPECT_EQ(root1->name(), "foo");
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed,
+ pipeline.RunOnModuleGroup(&module_group));
+ EXPECT_TRUE(changed);
+
+ EXPECT_EQ(root0->name(), "qux");
+ EXPECT_EQ(root1->name(), "bar");
+}
+
+TEST_F(HloPassPipelineTest, InvariantChecker) {
+ const string module_str = R"(
+HloModule InvariantChecker
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ {
+ // Run a pipeline with just the invariant checker. It should not fail
+ // because there is no 'bar' instruction in the module.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
+ EXPECT_FALSE(changed);
+ }
+
+ {
+ // Run a pipeline which renames 'foo' to 'bar' then an invariant checker
+ // which fails if there is an instruction named 'bar'.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+ pipeline.AddPass<FooToBarModulePass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after foo2bar"));
+ }
+
+ {
+ // Run the invariant-checker only pipeline again. It should fail this time.
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddInvariantChecker<BarBlowerUpper>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Module has instruction named bar"));
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Failed after pipeline-start"));
+ }
+}
+
+TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
+ // Running a module group pass on a module should produce an error.
+ const string module_str = R"(
+HloModule ModuleGroupPassOnModule
+
+ENTRY main {
+ a = f32[] parameter(0)
+ b = f32[] parameter(1)
+ ROOT foo = f32[] multiply(a, b)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ HloPassPipeline pipeline(TestName());
+ pipeline.AddPass<BazToQuxModuleGroupPass>();
+
+ Status status = pipeline.Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Module group pass cannot be run on a module"));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index bd6dd79b67..a438671936 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -1198,6 +1198,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
<< HumanReadableNumBytes(memory_limit_bytes_);
XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());
+ // Initialize pass object state.
+ computation_peak_memory_.clear();
+ rematerialized_computations_.clear();
+ instructions_rematerialized_ = 0;
+ net_instructions_added_ = 0;
+
TF_RET_CHECK(module->has_schedule());
TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index e2aaf18b3e..7330d73c09 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -33,7 +33,7 @@ namespace xla {
// CSE will undo the effects of this optimization and should not be run after
// this pass. In general, this pass should be run very late, immediately before
// code generation.
-class HloRematerialization : public HloPassInterface {
+class HloRematerialization : public HloModulePass {
public:
using ShapeSizeFunction = std::function<int64(const Shape&)>;
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
index d1cf644f82..fa34bddde1 100644
--- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -22,7 +22,7 @@ namespace xla {
// Unify subcomputations of a `HloModule`: if any computations are equal, choose
// one arbitrarily to use and delete the others.
-class HloSubcomputationUnification : public HloPassInterface {
+class HloSubcomputationUnification : public HloModulePass {
public:
absl::string_view name() const override {
return "subcomputation-unification";
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 773fc7d225..8549487702 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
CHECK_LE(operand_number, 2);
return operand_number == 0 || index.empty();
+ case HloOpcode::kDomain:
case HloOpcode::kTuple:
// These instructions always pass through their operands transparently.
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cbcb5..6eb6658904 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -1057,6 +1057,7 @@ Status VerifySendsAndRecvs(const HloModule& module) {
} // namespace
StatusOr<bool> HloVerifier::Run(HloModule* module) {
+ TF_RET_CHECK(!module->name().empty());
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 42e3027bf1..0cde4a31af 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -151,7 +151,7 @@ class ShapeVerifier : public DfsHloVisitor {
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
-class HloVerifier : public HloPassInterface {
+class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
index 85bb4a8b24..9c48b7db61 100644
--- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
+++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h
@@ -25,7 +25,7 @@ namespace xla {
// Pass which replaces all implicit broadcasts with their equivalent sequence of
// explicit broadcast and reshape instructions.
-class ImplicitBroadcastRemover : public HloPassInterface {
+class ImplicitBroadcastRemover : public HloModulePass {
public:
ImplicitBroadcastRemover() {}
~ImplicitBroadcastRemover() override {}
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index df9cbab915..3e238f97a0 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -366,7 +366,7 @@ class IndexedArrayAnalysis {
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to
// unconditionally add to the regular HLO pass pipeline.
-class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+class IndexedArrayAnalysisPrinterPass : public HloModulePass {
public:
absl::string_view name() const override;
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
index efa8ed3abc..e20af08fb7 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -24,7 +24,7 @@ namespace xla {
// A pass which performs inlining. Which can result, for example, in functions
// that were previously being mapped by Map instead directly applied to the
// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloPassInterface {
+class Inliner : public HloModulePass {
public:
~Inliner() override = default;
absl::string_view name() const override { return "inline"; }
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
index 93a74dbfa6..7e967f035c 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,7 +35,7 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloTestBase;
+using InlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
TEST_F(InlinerTest, MapMax) {
@@ -64,12 +64,12 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
@@ -98,12 +98,12 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
@@ -136,12 +136,12 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEntryComputation(std::move(computation));
Inliner inliner;
- EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
// Verify execution on CPU.
- auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
auto expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 8c907eae0c..3fdc2cee9a 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -295,6 +296,138 @@ InstructionFusion::ComputeGloballyUnfusible(
return do_not_duplicate;
}
+namespace {
+
+// A FusionQueue that uses reverse post order.
+//
+// We want to be able to remove arbitrary instructions from the post order and
+// also compare positions of instructions in the post order. To make this
+// possible, create vector of instructions in post order and create a map from
+// HloInstruction* to the instruction's index in the vector. An instruction is
+// "removed" from the vector by setting it's element to nullptr.
+class ReversePostOrderFusionQueue : public FusionQueue {
+ public:
+ explicit ReversePostOrderFusionQueue(HloComputation* computation) {
+ post_order_ = computation->MakeInstructionPostOrder();
+
+ for (size_t i = 0; i < post_order_.size(); ++i) {
+ InsertOrDie(&post_order_index_, post_order_[i], i);
+ }
+ }
+
+ std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() override {
+ // Instructions are "removed" from the post order by nulling out the element
+ // in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ while (!post_order_.empty() && post_order_.back() == nullptr) {
+ post_order_.pop_back();
+ }
+ if (post_order_.empty()) {
+ return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
+ }
+ // We want to iterate in reverse post order, so remove from the back of the
+ // vector.
+ HloInstruction* instruction = post_order_.back();
+ post_order_.pop_back();
+
+ CHECK(instruction != nullptr);
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index_.erase(instruction);
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+ // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+ // If we are considering the operands of C for fusion into C. We might
+ // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // A" = ..
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the order
+ // they appear int the queue. In the example, this ensures that B will be
+ // considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers;
+ sorted_operand_numbers.reserve(instruction->operands().size());
+ for (int i = 0; i < instruction->operands().size(); ++i) {
+ // This will happen if we have two possible instructions to fuse the
+ // same operand into; once the operand is fused into one instruction,
+ // the other instruction will get a new get-tuple-element as its
+ // operand, which is not in the queue.
+ // TODO(tjoerg): Look into fusing past these multi-output fuse points.
+ if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
+ continue;
+ }
+ sorted_operand_numbers.push_back(i);
+ }
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher priority in the queue come first.
+ return (
+ FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index_, instruction->mutable_operand(j)));
+ });
+ return std::make_pair(instruction, sorted_operand_numbers);
+ }
+
+ void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) override {
+ // Fusing an instruction into a fusion instruction can change the operand
+ // set of the fusion instruction. For simplicity just re-enqueue the
+ // instruction and reconsider it for further fusion in the next iteration.
+ InsertOrDie(&post_order_index_, fusion, post_order_.size());
+ post_order_.push_back(fusion);
+ }
+
+ void RemoveInstruction(HloInstruction* instruction) override {
+ post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
+ post_order_index_.erase(instruction);
+ }
+
+ private:
+ std::vector<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+};
+
+} // namespace
+
+std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer) {
+ return absl::make_unique<ReversePostOrderFusionQueue>(computation);
+}
+
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
VLOG(2) << "Before instruction fusion:";
XLA_VLOG_LINES(2, module->ToString());
@@ -306,111 +439,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
computation_ = computation;
reachability_ = computation_->ComputeReachability();
- // We want to be able to remove arbitrary instructions from the post order
- // and also compare positions of instructions in the post order. To make
- // this possible, create vector of instructions in post order and create a
- // map from HloInstruction* to the instruction's index in the vector. An
- // instruction is "removed" from the vector by setting it's element to
- // nullptr.
- std::vector<HloInstruction*> post_order =
- computation_->MakeInstructionPostOrder();
-
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
- for (size_t i = 0; i < post_order.size(); ++i) {
- InsertOrDie(&post_order_index, post_order[i], i);
- }
-
- HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order);
+ HloInstructionSet do_not_duplicate =
+ ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
+ auto fusion_queue =
+ GetFusionQueue(computation_, [&](HloInstruction* producer) {
+ return do_not_duplicate.count(producer) > 0;
+ });
// Instruction fusion effectively fuses edges in the computation graph
// (producer instruction -> consumer instruction) so we iterate over all
// edges. When we fuse an edge, we create a copy of the producer inside the
// fusion instruction.
- while (!post_order.empty()) {
- // We want to iterate in reverse post order, so remove from the back of
- // the vector.
- HloInstruction* instruction = post_order.back();
- post_order.pop_back();
-
- // Instructions are "removed" from the post order by nulling out the
- // element in the vector, so if the pointer is null, continue to the next
- // instruction in the sort.
+ while (true) {
+ auto next_entry =
+ fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
+ auto instruction = next_entry.first;
if (instruction == nullptr) {
- continue;
+ break;
}
- // Remove instruction from the index map to ensure the vector and map stay
- // consistent.
- post_order_index.erase(instruction);
-
if (!instruction->IsFusible() &&
instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- // Consider each operand of this instruction for fusion into this
- // instruction. We want to consider the operands in a particular order to
- // avoid creating duplicate instruction clones in the fusion instruction.
- // For example, consider the following expression:
- //
- // A = ...
- // B = op(A)
- // C = op(A, B)
- //
- // If we are considering the operands of C for fusion into C. We might
- // fuse A or B first. If we fuse A first, we get:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // C' = op(A', B) }
- //
- // Where A' and C' are clones of A and C, respectively. Now only B is an
- // operand of the fusion instruction C_fusion, so then we fuse B:
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // B' = op(A)
- // C' = op(A', B') }
- //
- // Now A is an operand of C_fusion again, so we then fuse A (again!):
- //
- // A = ...
- // B = op(A)
- // C_fusion = { A' = ...
- // A" = ..
- // B' = op(A")
- // C' = op(A', B') }
- //
- // We prevent this duplication by considering the operands in the reverse
- // order they appear in the instruction post order. In the example, this
- // ensures that B will be considered before A.
- //
- // We store the original indices of the operands to pass to ShouldFuse.
- std::vector<int64> sorted_operand_numbers;
- sorted_operand_numbers.reserve(instruction->operands().size());
- for (int i = 0; i < instruction->operands().size(); ++i) {
- // This will happen if we have two possible instructions to fuse the
- // same operand into; once the operand is fused into one instruction,
- // the other instruction will get a new get-tuple-element as its
- // operand, which is not in the post-order index.
- // TODO(tjoerg): Look into fusing past these multi-output fuse points.
- if (post_order_index.find(instruction->mutable_operand(i)) ==
- post_order_index.end()) {
- continue;
- }
- sorted_operand_numbers.push_back(i);
- }
- std::sort(
- sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
- [&](int64 i, int64 j) {
- // Instructions with higher indices in the post order come
- // first.
- return (
- FindOrDie(post_order_index, instruction->mutable_operand(i)) >
- FindOrDie(post_order_index, instruction->mutable_operand(j)));
- });
+ std::vector<int64>& sorted_operand_numbers = next_entry.second;
for (int64 i : sorted_operand_numbers) {
HloInstruction* operand = instruction->mutable_operand(i);
@@ -425,32 +478,31 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
// TODO(tjoerg): Consider making multi-output fusion the default.
if (ShouldFuse(instruction, i) &&
do_not_duplicate.count(operand) == 0) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = Fuse(operand, instruction);
} else if (ShouldFuseIntoMultiOutput(instruction, i) &&
!MultiOutputFusionCreatesCycle(operand, instruction)) {
+ fusion_queue->PreFusion(operand, instruction);
fusion_instruction = FuseIntoMultiOutput(operand, instruction);
} else {
continue;
}
- // Fusing an instruction into a fusion instruction can change the
- // operand set of the fusion instruction. For simplicity just push the
- // instruction to the top of the post_order and reconsider it for
- // further fusion in the next iteration of the outer loop.
- post_order.push_back(fusion_instruction);
- InsertOrDie(&post_order_index, fusion_instruction,
- post_order.size() - 1);
+ fusion_queue->OnFusingInstruction(fusion_instruction, operand,
+ instruction);
changed = true;
if (operand->user_count() == 0) {
- // Operand is now dead. Remove from post order by setting its
- // location to nullptr.
- post_order[FindOrDie(post_order_index, operand)] = nullptr;
- post_order_index.erase(operand);
-
+ do_not_duplicate.erase(operand);
+ // Operand is now dead. Remove from queue.
+ fusion_queue->RemoveInstruction(operand);
// Remove from computation.
TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
}
+
+ if (fusion_instruction != instruction) {
+ do_not_duplicate.erase(instruction);
+ }
break;
}
}
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index 00b658959a..7e1196fb7f 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -24,12 +24,39 @@ limitations under the License.
namespace xla {
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+ // as operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify the removal of an
+ // instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
// code generation. Derived classes define ShouldFuse method to select which
// instructions to fuse.
-class InstructionFusion : public HloPassInterface {
+class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
@@ -48,6 +75,13 @@ class InstructionFusion : public HloPassInterface {
static bool IsExpensive(const HloInstruction& instruction);
protected:
+ // Returns a FusionQueue that implements custom order of instructions being
+ // fused. The default implementation processes consumers in reverse post
+ // order.
+ virtual std::unique_ptr<FusionQueue> GetFusionQueue(
+ HloComputation* computation,
+ const std::function<bool(HloInstruction*)>& skip_producer);
+
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
// Derived classes should define this method to specify which instructions
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index cf545031d3..e29c199c42 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -281,7 +281,7 @@ class ChannelLayoutConstraints {
// HLO pass which assigns layouts to all instructions in the HLO module while
// satisfying all necessary invariants and minimizing cost.
-class LayoutAssignment : public HloPassInterface {
+class LayoutAssignment : public HloModulePass {
public:
// entry_computation_layout is modified to populate a layout for the result in
// the case that no particular layout is requested.
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index d2c52651c4..0344626b26 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -44,7 +44,7 @@ namespace xla {
// Note that the reachability map is updated based on the original computation.
// This works because the reachability is monotonically increasing with
// instruction fusion.
-class MultiOutputFusion : public HloPassInterface {
+class MultiOutputFusion : public HloModulePass {
public:
MultiOutputFusion(int64 fuel) : fuel_(fuel) {}
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index bd8fb17a23..ac2f79674f 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) {
}
/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ if (name.empty()) {
+ return "";
+ }
string result = name;
- CHECK(!result.empty()) << "name should not be empty";
char c = static_cast<unsigned char>(result[0]);
if (!isalpha(c) && c != '_') {
result[0] = '_';
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
index 4869db79e7..380cde0e6a 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher.h
+++ b/tensorflow/compiler/xla/service/pattern_matcher.h
@@ -17,8 +17,12 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
#include "absl/strings/string_view.h"
+#include "absl/utility/utility.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -116,15 +120,82 @@ namespace xla {
// .WithOperand(1, Op(&c))
// .WithOperand(2, Op(&d))
//
+
+struct MatchOption {
+ // If true, actually capture matched item into the user pointer.
+ bool capture;
+};
+
template <typename Value, typename Pattern>
-bool Match(Value* value, const Pattern& pattern) {
- return pattern.Match(value);
+bool Match(Value* value, const Pattern& pattern,
+ MatchOption option = {/*.capture=*/true}) {
+ if (option.capture) {
+ auto new_option = option;
+ new_option.capture = false;
+ if (!pattern.Match(value, new_option)) {
+ return false;
+ }
+ }
+ return pattern.Match(value, option);
}
namespace match {
namespace detail {
+template <typename Item, typename... Patterns>
+class AllOfPattern {
+ public:
+ explicit AllOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ // This invariant is guaranteed by the top-level Match and AnyOf.
+ DCHECK(matched || !option.capture);
+ return matched;
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ return std::get<index>(patterns_).Match(item, option) &&
+ MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return true;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
+} // namespace detail
+
+// Returns a pattern that represents the conjunction of all input patterns. All
+// patterns need to match in order to have the AllOf pattern match.
+//
+// TODO(timshen): Currently AllOf is still nested, e.g. AllOf<AllOf<A>, B> is
+// not AllOf<A, B>. We might want to flatten the AllOf type structure if the
+// C++ compile error message gets annoying.
+template <typename Item, typename... Patterns>
+detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf(
+ const Patterns&... patterns) {
+ return detail::AllOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
+namespace detail {
+
template <typename LayoutType, typename Impl>
class LayoutPattern;
@@ -132,57 +203,61 @@ class LayoutPattern;
// nullptr.
class LayoutPatternBaseImpl {
public:
- bool Match(const ::xla::Layout* layout) const { return layout != nullptr; }
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout != nullptr;
+ }
};
// A LayoutPattern implementation that matches only if the layout equals a
// Layout proto.
-template <typename Previous>
class LayoutPatternEqualImpl {
public:
- explicit constexpr LayoutPatternEqualImpl(const Previous& previous,
- const ::xla::Layout* layout)
- : previous_(previous), layout_(layout) {}
+ explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout)
+ : layout_(layout) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout);
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return LayoutUtil::Equal(*layout_, *layout);
}
private:
- Previous previous_;
const ::xla::Layout* layout_;
};
// A LayoutPattern implementation that matches only if the layout has a given
// format.
-template <typename Previous>
class LayoutPatternFormatImpl {
public:
- explicit constexpr LayoutPatternFormatImpl(const Previous& previous,
- Format format)
- : previous_(previous), format_(format) {}
+ explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {}
- bool Match(const ::xla::Layout* layout) const {
- return previous_.Match(layout) && layout->format() == format_;
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ return layout->format() == format_;
}
private:
- Previous previous_;
Format format_;
};
// A pattern that matches Layouts.
template <typename LayoutType, typename Impl>
class LayoutPattern {
+ private:
+ template <typename NewImpl>
+ LayoutPattern<LayoutType, AllOfPattern<::xla::Layout, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return LayoutPattern<LayoutType,
+ AllOfPattern<::xla::Layout, Impl, NewImpl>>(
+ AllOf<Layout>(impl_, std::move(new_impl)), matched_layout_);
+ }
+
public:
explicit constexpr LayoutPattern(const Impl& impl,
LayoutType** matched_layout)
: impl_(impl), matched_layout_(matched_layout) {}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(const ::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(const ::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -191,9 +266,9 @@ class LayoutPattern {
}
// Returns true and captures the layout iff it matches the pattern.
- bool Match(::xla::Layout* layout) const {
- if (impl_.Match(layout)) {
- if (matched_layout_) {
+ bool Match(::xla::Layout* layout, MatchOption option) const {
+ if (impl_.Match(layout, option)) {
+ if (option.capture && matched_layout_) {
*matched_layout_ = layout;
}
return true;
@@ -203,24 +278,21 @@ class LayoutPattern {
// Modifies the pattern to match only if the layout equals the given proto.
// The layout must outlive the returned pattern.
- constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo(
- const ::xla::Layout* layout) const {
- return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>(
- LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_);
+ constexpr auto EqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) {
+ return AppendImpl(LayoutPatternEqualImpl(layout));
}
// Modifies the pattern to match only if the layout has a dense format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithDenseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_);
+ constexpr auto WithDenseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(DENSE));
}
// Modifies the pattern to match only if the layout has a sparse format.
- constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
- WithSparseFormat() const {
- return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
- LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_);
+ constexpr auto WithSparseFormat() const
+ -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) {
+ return AppendImpl(LayoutPatternFormatImpl(SPARSE));
}
private:
@@ -228,8 +300,72 @@ class LayoutPattern {
LayoutType** matched_layout_;
};
+template <typename Item, typename... Patterns>
+class AnyOfPattern {
+ public:
+ explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) {}
+
+ bool Match(const Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ bool Match(Item* item, MatchOption option) const {
+ return MatchImpl(item, option, std::integral_constant<size_t, 0>());
+ }
+
+ private:
+ template <typename ItemType, size_t index>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, index>) const {
+ auto new_option = option;
+ new_option.capture = false;
+ // Try to match the sub-pattern without capturing behavior.
+ if (std::get<index>(patterns_).Match(item, new_option)) {
+ // Capture the branch.
+ if (option.capture) {
+ // TODO(timshen): Currently the behavior can be exponential. Optimize it
+ // with memoization or recording the matched sub-pattern index, if it
+ // takes too long to run.
+ //
+ // Specifically, the "memoization" approach is to create an empty
+ // container with the key (pattern, instruction), and value as whether
+ // matched or not.
+ //
+ // Alternatively, we may run the pattern matching with captures off, but
+ // instead record a "trace" somewhere, indicating how exactly the
+ // pattern matches the input. For example, the trace information for
+ // AnyOf will be a runtime number indicate which sub-pattern is matched.
+ // Then we run another pass to do captures only with the help of the
+ // trace.
+ bool ret = std::get<index>(patterns_).Match(item, option);
+ DCHECK(ret);
+ }
+ return true;
+ }
+ return MatchImpl(item, option, std::integral_constant<size_t, index + 1>());
+ }
+
+ template <typename ItemType>
+ bool MatchImpl(ItemType* item, MatchOption option,
+ std::integral_constant<size_t, sizeof...(Patterns)>) const {
+ return false;
+ }
+
+ std::tuple<Patterns...> patterns_;
+};
+
} // namespace detail
+// Returns a pattern that represents the logical disjunction of the input
+// patterns. The returned pattern matches from left to right, and stops on the
+// first match.
+template <typename Item, typename... Patterns>
+detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf(
+ const Patterns&... patterns) {
+ return detail::AnyOfPattern<typename std::remove_const<Item>::type,
+ Patterns...>(patterns...);
+}
+
// Creates a layout pattern that will capture the matched layout in the
// argument.
inline constexpr detail::LayoutPattern<const ::xla::Layout,
@@ -258,172 +394,145 @@ class ShapePattern;
// nullptr.
class ShapePatternBaseImpl {
public:
- bool Match(const ::xla::Shape* shape) const { return shape != nullptr; }
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape != nullptr;
+ }
};
// A ShapePattern implementation that matches only if the shape equals a Shape
// proto.
-template <typename Previous>
class ShapePatternEqualImpl {
public:
- explicit constexpr ShapePatternEqualImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Equal(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape is compatible to
// a Shape proto.
-template <typename Previous>
class ShapePatternCompatibleImpl {
public:
- explicit constexpr ShapePatternCompatibleImpl(const Previous& previous,
- const ::xla::Shape* shape)
- : previous_(previous), shape_(shape) {}
+ explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape)
+ : shape_(shape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Compatible(*shape_, *shape);
}
private:
- Previous previous_;
const ::xla::Shape* shape_;
};
// A ShapePattern implementation that matches only if the shape has a given
// element type.
-template <typename Previous>
class ShapePatternElementTypeImpl {
public:
- explicit constexpr ShapePatternElementTypeImpl(const Previous& previous,
- PrimitiveType element_type)
- : previous_(previous), element_type_(element_type) {}
+ explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type)
+ : element_type_(element_type) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && shape->element_type() == element_type_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return shape->element_type() == element_type_;
}
private:
- Previous previous_;
PrimitiveType element_type_;
};
// A ShapePattern implementation that matches only if the shape is scalar.
-template <typename Previous>
class ShapePatternIsScalarImpl {
public:
- explicit constexpr ShapePatternIsScalarImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsScalarImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsScalar(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsScalar(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is an array
-template <typename Previous>
class ShapePatternIsArrayImpl {
public:
- explicit constexpr ShapePatternIsArrayImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsArrayImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsArray(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsArray(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape is a tuple.
-template <typename Previous>
class ShapePatternIsTupleImpl {
public:
- explicit constexpr ShapePatternIsTupleImpl(const Previous& previous)
- : previous_(previous) {}
+ explicit constexpr ShapePatternIsTupleImpl() {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IsTuple(*shape);
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IsTuple(*shape);
}
-
- private:
- Previous previous_;
};
// A ShapePattern implementation that matches only if the shape has a given
// rank.
-template <typename Previous>
class ShapePatternRankImpl {
public:
- explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank)
- : previous_(previous), rank_(rank) {}
+ explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_;
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::Rank(*shape) == rank_;
}
private:
- Previous previous_;
int64 rank_;
};
// A ShapePattern implementation that matches only if the shape has a layout
// that matches a given pattern.
-template <typename Previous, typename LayoutType, typename LayoutImpl>
+template <typename LayoutType, typename LayoutImpl>
class ShapePatternLayoutImpl {
public:
explicit constexpr ShapePatternLayoutImpl(
- const Previous& previous,
const LayoutPattern<LayoutType, LayoutImpl>& layout)
- : previous_(previous), layout_(layout) {}
+ : layout_(layout) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(&shape->layout());
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(&shape->layout(), option);
}
- bool Match(Shape* shape) const {
- return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
- layout_.Match(shape->mutable_layout());
+ bool Match(Shape* shape, MatchOption option) const {
+ return LayoutUtil::HasLayout(*shape) &&
+ layout_.Match(shape->mutable_layout(), option);
}
private:
- Previous previous_;
LayoutPattern<LayoutType, LayoutImpl> layout_;
};
// A ShapePattern implementation that matches only if the shape has a subshape
// that matches a given pattern.
-template <typename Previous, typename SubshapeType, typename SubshapeImpl>
+template <typename SubshapeType, typename SubshapeImpl>
class ShapePatternSubshapeImpl {
public:
explicit ShapePatternSubshapeImpl(
- const Previous& previous, ShapeIndexView index,
+ ShapeIndexView index,
const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
- : previous_(previous), index_(index), subshape_(subshape) {}
+ : index_(index), subshape_(subshape) {}
- bool Match(const ::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_));
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option);
}
- bool Match(::xla::Shape* shape) const {
- return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
- subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_));
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ return ShapeUtil::IndexIsValid(*shape, index_) &&
+ subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_),
+ option);
}
private:
- Previous previous_;
ShapeIndexView index_;
ShapePattern<SubshapeType, SubshapeImpl> subshape_;
};
@@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl {
// A pattern that matches Shapes.
template <typename ShapeType, typename Impl>
class ShapePattern {
+ private:
+ template <typename NewImpl>
+ ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>> AppendImpl(
+ NewImpl new_impl) const {
+ return ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>>(
+ AllOf<Shape>(impl_, std::move(new_impl)), matched_shape_);
+ }
+
public:
explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
: impl_(impl), matched_shape_(matched_shape) {}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(const ::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(const ::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -447,9 +564,9 @@ class ShapePattern {
}
// Returns true and captures the shape iff it matches the pattern.
- bool Match(::xla::Shape* shape) const {
- if (impl_.Match(shape)) {
- if (matched_shape_) {
+ bool Match(::xla::Shape* shape, MatchOption option) const {
+ if (impl_.Match(shape, option)) {
+ if (option.capture && matched_shape_) {
*matched_shape_ = shape;
}
return true;
@@ -459,108 +576,90 @@ class ShapePattern {
// Modifies the pattern to match only if the shape equals the given proto.
// The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo(
- const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>(
- ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto EqualTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) {
+ return AppendImpl(ShapePatternEqualImpl(shape));
}
// Modifies the pattern to match only if the shape is compatible to the given
// proto. The layout must outlive the returned pattern.
- constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>
- CompatibleTo(const ::xla::Shape* shape) const {
- return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>(
- ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_);
+ constexpr auto CompatibleTo(const ::xla::Shape* shape) const
+ -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) {
+ return AppendImpl(ShapePatternCompatibleImpl(shape));
}
// Modifies the pattern to match only if the shape has the given element type.
- constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>
- WithElementType(PrimitiveType element_type) const {
- return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>(
- ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_);
+ constexpr auto WithElementType(PrimitiveType element_type) const
+ -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) {
+ return AppendImpl(ShapePatternElementTypeImpl(element_type));
}
// Modifies the pattern to match only if the shape is scalar.
- constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar()
- const {
- return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>(
- ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsScalar() const
+ -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) {
+ return AppendImpl(ShapePatternIsScalarImpl());
}
// Modifies the pattern to match only if the shape is an array.
- constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray()
- const {
- return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>(
- ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsArray() const
+ -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) {
+ return AppendImpl(ShapePatternIsArrayImpl());
}
// Modifies the pattern to match only if the shape is a tuple.
- constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple()
- const {
- return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>(
- ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_);
+ constexpr auto IsTuple() const
+ -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) {
+ return AppendImpl(ShapePatternIsTupleImpl());
}
// Modifies the pattern to match only if the shape has the given rank.
- constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank(
- int64 rank) const {
- return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>(
- ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_);
+ constexpr auto WithRank(int64 rank) const
+ -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) {
+ return AppendImpl(ShapePatternRankImpl(rank));
}
// Modifies the pattern to match only if the shape has a layout that matches
// the given pattern.
template <typename LayoutType, typename LayoutImpl>
- constexpr ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>
- WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
- return ShapePattern<ShapeType,
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>(
- ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout),
- matched_shape_);
- }
-
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternEqualImpl<LayoutPatternBaseImpl>>>
- WithLayoutEqualTo(const ::xla::Layout* layout) const {
+ auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const
+ -> decltype(this->AppendImpl(
+ ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) {
+ return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout));
+ }
+
+ constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const
+ -> decltype(this->WithLayout(Layout().EqualTo(layout))) {
return WithLayout(Layout().EqualTo(layout));
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsDenseArray() const {
+ constexpr auto IsDenseArray() const
+ -> decltype(this->WithLayout(Layout().WithDenseFormat())) {
return WithLayout(Layout().WithDenseFormat());
}
- constexpr ShapePattern<
- ShapeType,
- ShapePatternLayoutImpl<Impl, const ::xla::Layout,
- LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
- IsSparseArray() const {
+ constexpr auto IsSparseArray() const
+ -> decltype(this->WithLayout(Layout().WithSparseFormat())) {
return WithLayout(Layout().WithSparseFormat());
}
// Modifies the pattern to match only if the shape has a subshape that matches
// the given pattern.
template <typename SubshapeType, typename SubshapeImpl>
+ auto WithSubshape(ShapeIndexView index,
+ const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
+ const -> decltype(this->AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index,
+ subshape))) {
+ return AppendImpl(
+ ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape));
+ }
+
ShapePattern<ShapeType,
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>
- WithSubshape(ShapeIndexView index,
- const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
- return ShapePattern<
- ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>(
- ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index,
- subshape),
- matched_shape_);
- }
-
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternEqualImpl<ShapePatternBaseImpl>>>
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternEqualImpl>>>>
WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
return WithSubshape(index,
ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
@@ -568,9 +667,12 @@ class ShapePattern {
.EqualTo(shape));
}
- ShapePattern<ShapeType, ShapePatternSubshapeImpl<
- Impl, const ::xla::Shape,
- ShapePatternCompatibleImpl<ShapePatternBaseImpl>>>
+ ShapePattern<ShapeType,
+ AllOfPattern<Shape, Impl,
+ ShapePatternSubshapeImpl<
+ const ::xla::Shape,
+ AllOfPattern<::xla::Shape, ShapePatternBaseImpl,
+ ShapePatternCompatibleImpl>>>>
WithSubshapeCompatibleTo(ShapeIndexView index,
const ::xla::Shape* shape) const {
return WithSubshape(index,
@@ -611,159 +713,169 @@ class HloInstructionPattern;
// instruction is not nullptr.
class HloInstructionPatternBaseImpl {
public:
- bool Match(const ::xla::HloInstruction* inst) const {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
return inst != nullptr;
}
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given name.
-template <typename Previous>
class HloInstructionPatternNameImpl {
public:
- explicit HloInstructionPatternNameImpl(const Previous& previous,
- absl::string_view name)
- : previous_(previous), name_(name) {}
+ explicit HloInstructionPatternNameImpl(absl::string_view name)
+ : name_(name) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->name() == name_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->name() == name_;
}
private:
- Previous previous_;
absl::string_view name_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a given opcode.
-template <typename Previous>
class HloInstructionPatternOpcodeImpl {
public:
- explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous,
- HloOpcode opcode,
+ explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode,
bool invert)
- : previous_(previous), opcode_(opcode), invert_(invert) {}
+ : opcode_(opcode), invert_(invert) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return (invert_ ^ (inst->opcode() == opcode_));
}
private:
- Previous previous_;
HloOpcode opcode_;
bool invert_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has a shape that matches a given pattern.
-template <typename Previous, typename ShapeType, typename ShapeImpl>
+template <typename ShapeType, typename ShapeImpl>
class HloInstructionPatternShapeImpl {
public:
explicit constexpr HloInstructionPatternShapeImpl(
- const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape)
- : previous_(previous), shape_(shape) {}
+ const ShapePattern<ShapeType, ShapeImpl>& shape)
+ : shape_(shape) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(&inst->shape());
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(&inst->shape(), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && shape_.Match(inst->mutable_shape());
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return shape_.Match(inst->mutable_shape(), option);
}
private:
- Previous previous_;
ShapePattern<ShapeType, ShapeImpl> shape_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// has an operand that matches a given pattern.
-template <typename Previous, typename OperandType, typename OperandImpl>
+template <typename OperandType, typename OperandImpl>
class HloInstructionPatternOperandImpl {
public:
explicit constexpr HloInstructionPatternOperandImpl(
- const Previous& previous, int64 operand_index,
+ int64 operand_index,
const HloInstructionPattern<OperandType, OperandImpl>& operand)
- : previous_(previous), operand_index_(operand_index), operand_(operand) {}
+ : operand_index_(operand_index), operand_(operand) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->operand(operand_index_));
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->operand(operand_index_), option);
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
- operand_.Match(inst->mutable_operand(operand_index_));
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return operand_index_ < inst->operand_count() &&
+ operand_.Match(inst->mutable_operand(operand_index_), option);
}
private:
- Previous previous_;
int64 operand_index_;
HloInstructionPattern<OperandType, OperandImpl> operand_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a fusion node with a particular kind.
-template <typename Previous>
class HloInstructionPatternFusionKindImpl {
public:
explicit constexpr HloInstructionPatternFusionKindImpl(
- const Previous& previous, ::xla::HloInstruction::FusionKind kind)
- : previous_(previous), kind_(kind) {}
+ ::xla::HloInstruction::FusionKind kind)
+ : kind_(kind) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion &&
- inst->fusion_kind() == kind_;
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_;
}
private:
- Previous previous_;
::xla::HloInstruction::FusionKind kind_;
};
// An HloInstructionPattern implementation that matches only if the instruction
// is a kGetTupleElement with a particular tuple index.
-template <typename Previous>
class HloInstructionPatternTupleIndexImpl {
public:
- explicit constexpr HloInstructionPatternTupleIndexImpl(
- const Previous& previous, int64 tuple_index)
- : previous_(previous), tuple_index_(tuple_index) {}
+ explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index)
+ : tuple_index_(tuple_index) {}
- bool Match(const ::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
- bool Match(::xla::HloInstruction* inst) const {
- return previous_.Match(inst) &&
- inst->opcode() == HloOpcode::kGetTupleElement &&
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ return inst->opcode() == HloOpcode::kGetTupleElement &&
inst->tuple_index() == tuple_index_;
}
private:
- Previous previous_;
int64 tuple_index_;
};
+template <typename ItemType, typename Predicate>
+class HloPredicatePatternImpl {
+ public:
+ explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {}
+
+ bool Match(const ItemType* item, MatchOption option) const {
+ return pred_(item);
+ }
+
+ bool Match(ItemType* item, MatchOption option) const { return pred_(item); }
+
+ private:
+ Predicate pred_;
+};
+
+struct PatternFriend;
+
// A pattern that matches HloInstructions.
template <typename HloInstructionType, typename Impl>
class HloInstructionPattern {
+ private:
+ template <typename NewImpl>
+ HloInstructionPattern<HloInstructionType,
+ AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>
+ AppendImpl(NewImpl new_impl) const {
+ return HloInstructionPattern<
+ HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>(
+ AllOf<HloInstruction>(impl_, std::move(new_impl)), matched_inst_);
+ }
+
public:
explicit constexpr HloInstructionPattern(const Impl& impl,
HloInstructionType** matched_inst)
: impl_(impl), matched_inst_(matched_inst) {}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(const ::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(const ::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -772,9 +884,9 @@ class HloInstructionPattern {
}
// Returns true and captures the instruction iff it matches the pattern.
- bool Match(::xla::HloInstruction* inst) const {
- if (impl_.Match(inst)) {
- if (matched_inst_) {
+ bool Match(::xla::HloInstruction* inst, MatchOption option) const {
+ if (impl_.Match(inst, option)) {
+ if (option.capture && matched_inst_) {
*matched_inst_ = inst;
}
return true;
@@ -783,102 +895,87 @@ class HloInstructionPattern {
}
// Modifies the pattern to match only if the instruction has the given name.
- HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
- WithName(absl::string_view name) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternNameImpl<Impl>>(
- HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
+ auto WithName(absl::string_view name) const
+ -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) {
+ return AppendImpl(HloInstructionPatternNameImpl(name));
}
// Modifies the pattern to match only if the instruction has the given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false),
- matched_inst_);
+ auto WithOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ false))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false));
}
// Modifies the pattern to match only if the instruction does not have the
// given opcode.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- WithoutOpcode(HloOpcode opcode) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>(
- HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true),
- matched_inst_);
+ auto WithoutOpcode(HloOpcode opcode) const
+ -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode,
+ true))) {
+ return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true));
}
// Modifies the pattern to match only if the instruction is a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsConstant() const {
+ constexpr auto IsConstant() const
+ -> decltype(this->WithOpcode(HloOpcode::kConstant)) {
return WithOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction is not a constant.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternOpcodeImpl<Impl>>
- IsNonConstant() const {
+ constexpr auto IsNonConstant() const
+ -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) {
return WithoutOpcode(HloOpcode::kConstant);
}
// Modifies the pattern to match only if the instruction has a shape that
// matches the given pattern.
template <typename ShapeType, typename ShapeImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>
- WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>(
- HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_,
- shape),
- matched_inst_);
+ constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape)
+ const -> decltype(this->AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) {
+ return AppendImpl(
+ HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape));
}
// Modifies the pattern to match only if the instruction has an operand that
// matches the given pattern.
template <typename OperandType, typename OperandImpl>
- constexpr HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>
- WithOperand(
+ constexpr auto WithOperand(
int64 operand_index,
- const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
- return HloInstructionPattern<
- HloInstructionType,
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>(
- HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>(
- impl_, operand_index, operand),
- matched_inst_);
+ const HloInstructionPattern<OperandType, OperandImpl>& operand) const
+ -> decltype(this->AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand))) {
+ return AppendImpl(
+ HloInstructionPatternOperandImpl<OperandType, OperandImpl>(
+ operand_index, operand));
}
// Modifies the pattern to match only if the instruction is a fusion node with
// the given kind.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>
- WithFusionKind(HloInstruction::FusionKind kind) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternFusionKindImpl<Impl>>(
- HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_);
+ constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const
+ -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) {
+ return AppendImpl(HloInstructionPatternFusionKindImpl(kind));
}
// Modifies the pattern to match only if the instruction is a
// get-tuple-element with the given tuple index.
- constexpr HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>
- WithTupleIndex(int64 tuple_index) const {
- return HloInstructionPattern<HloInstructionType,
- HloInstructionPatternTupleIndexImpl<Impl>>(
- HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index),
- matched_inst_);
+ constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype(
+ this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) {
+ return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index));
}
private:
+ template <typename Predicate>
+ constexpr auto WithPredicate(Predicate pred) const -> decltype(
+ this->AppendImpl(HloPredicatePatternImpl<HloInstruction, Predicate>(
+ std::move(pred)))) {
+ return AppendImpl(
+ HloPredicatePatternImpl<HloInstruction, Predicate>(std::move(pred)));
+ }
+
+ friend struct PatternFriend;
+
Impl impl_;
HloInstructionType** matched_inst_;
};
@@ -1005,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose)
.WithOperand(0, std::forward<Lhs>(lhs)) \
.WithOperand(1, std::forward<Rhs>(rhs)); \
}
-XLA_BINOP_PATTERN(Add)
+
+#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \
+ XLA_BINOP_PATTERN(NAME) \
+ \
+ template <typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \
+ return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \
+ } \
+ \
+ template <typename HloInstructionType, typename Lhs, typename Rhs> \
+ inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \
+ Rhs&& rhs) \
+ ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs))) { \
+ return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \
+ NAME(matched_inst, rhs, lhs)); \
+ }
+XLA_COMMUTATIVE_BINOP_PATTERN(Add)
XLA_BINOP_PATTERN(Atan2)
XLA_BINOP_PATTERN(Divide)
XLA_BINOP_PATTERN(Complex)
XLA_BINOP_PATTERN(Dot)
-XLA_BINOP_PATTERN(Eq)
+XLA_COMMUTATIVE_BINOP_PATTERN(Eq)
XLA_BINOP_PATTERN(Gather)
XLA_BINOP_PATTERN(Ge)
XLA_BINOP_PATTERN(Gt)
XLA_BINOP_PATTERN(Le)
XLA_BINOP_PATTERN(Lt)
-XLA_BINOP_PATTERN(Maximum)
-XLA_BINOP_PATTERN(Minimum)
-XLA_BINOP_PATTERN(Multiply)
-XLA_BINOP_PATTERN(Ne)
+XLA_COMMUTATIVE_BINOP_PATTERN(Maximum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Minimum)
+XLA_COMMUTATIVE_BINOP_PATTERN(Multiply)
+XLA_COMMUTATIVE_BINOP_PATTERN(Ne)
XLA_BINOP_PATTERN(Outfeed)
XLA_BINOP_PATTERN(Power)
XLA_BINOP_PATTERN(Remainder)
XLA_BINOP_PATTERN(Send)
XLA_BINOP_PATTERN(Subtract)
-XLA_BINOP_PATTERN(And)
-XLA_BINOP_PATTERN(Or)
+XLA_COMMUTATIVE_BINOP_PATTERN(And)
+XLA_COMMUTATIVE_BINOP_PATTERN(Or)
XLA_BINOP_PATTERN(ShiftLeft)
XLA_BINOP_PATTERN(ShiftRightArithmetic)
XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_COMMUTATIVE_BINOP_PATTERN
#undef XLA_BINOP_PATTERN
// Helpers for ternary instructions.
@@ -1070,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp);
XLA_TERNOP_PATTERN(Select);
#undef XLA_TERNOP_PATTERN
+namespace detail {
+struct PatternFriend {
+ template <typename T>
+ static auto ConstantScalar(T constant) -> decltype(
+ Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(
+ std::declval<std::function<bool(const HloInstruction*)>>())) {
+ std::function<bool(const HloInstruction*)> pred =
+ [constant](const HloInstruction* instr) {
+ const auto& literal = Cast<HloConstantInstruction>(instr)->literal();
+ auto status_or_const = LiteralUtil::CreateR0(constant).Convert(
+ literal.shape().element_type());
+ return status_or_const.ok() &&
+ literal == status_or_const.ConsumeValueOrDie();
+ };
+
+ return Constant()
+ .WithShape(match::Shape().IsScalar())
+ .WithPredicate(std::move(pred));
+ }
+};
+} // namespace detail
+
// Helpers for matching non-constant instructions.
inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
return Op().IsNonConstant();
@@ -1107,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg,
.WithTupleIndex(tuple_index);
}
+template <typename T>
+inline auto ConstantScalar(T constant)
+ -> decltype(detail::PatternFriend::ConstantScalar(constant)) {
+ return detail::PatternFriend::ConstantScalar(constant);
+}
+
} // namespace match
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index a530581c34..3ab7b7fd71 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) {
EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1)));
}
+TEST(PatternMatcherTest, AnyOf) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(1))));
+ EXPECT_TRUE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1),
+ match::ConstantScalar(0))));
+ EXPECT_FALSE(
+ Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0),
+ match::ConstantScalar(2))));
+}
+
+TEST(PatternMatcherTest, ConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_TRUE(Match(root, match::ConstantScalar(42)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(41)));
+ EXPECT_FALSE(Match(root, match::ConstantScalar(0)));
+}
+
+TEST(PatternMatcherTest, NoMatchConstantScalar) {
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ EXPECT_FALSE(Match(root, match::ConstantScalar(42)));
+}
+
+TEST(PatternMatcherTest, MultiplyAnyOrder) {
+ using match::ConstantScalar;
+ using match::MultiplyAnyOrder;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+ const HloInstruction* instr;
+
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52))));
+ EXPECT_TRUE(Match(
+ root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42))));
+}
+
+TEST(PatternMatcherTest, AnyOfShortCircuit) {
+ using match::AnyOf;
+ using match::Multiply;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ lhs = f16[] constant(42)
+ rhs = f16[] constant(52)
+ ROOT multiply = f16[] multiply(lhs, rhs)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any))));
+ EXPECT_NE(nullptr, mul);
+ EXPECT_EQ(nullptr, any);
+ }
+ {
+ const HloInstruction* mul = nullptr;
+ const HloInstruction* any = nullptr;
+
+ ASSERT_TRUE(Match(
+ root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op()))));
+ EXPECT_NE(nullptr, any);
+ EXPECT_EQ(nullptr, mul);
+ }
+}
+
+TEST(PatternMatcherTest, AllOf) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar());
+ auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16));
+ ASSERT_TRUE(Match(root, scalar_pattern));
+ ASSERT_TRUE(Match(root, f16_pattern));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern)));
+ EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern)));
+ EXPECT_FALSE(
+ Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern)));
+}
+
+TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) {
+ using match::AllOf;
+ using match::Broadcast;
+ using match::Constant;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_FALSE(
+ Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op()))));
+ EXPECT_EQ(nullptr, constant);
+ ASSERT_TRUE(Match(root, Constant(&constant)));
+ EXPECT_NE(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestNoCapture) {
+ using match::Constant;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ ROOT v = f16[] constant(42)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* constant = nullptr;
+ ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false}));
+ EXPECT_EQ(nullptr, constant);
+}
+
+TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) {
+ using match::Add;
+ using match::AddAnyOrder;
+ using match::AnyOf;
+ using match::Op;
+
+ constexpr char kModuleStr[] = R"(
+ HloModule test_module
+ ENTRY test {
+ u = f16[] parameter(0)
+ v = f16[] parameter(1)
+ ROOT add = f16[] add(u, v)
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
+ auto* root = hlo_module->entry_computation()->root_instruction();
+
+ const HloInstruction* addend0 = nullptr;
+ const HloInstruction* addend1 = nullptr;
+ const HloInstruction* addend2 = nullptr;
+ auto add2_pattern = Add(Op(&addend0), Op(&addend1));
+ auto add3_pattern = AnyOf<HloInstruction>(
+ AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0));
+
+ ASSERT_TRUE(Match(root, add3_pattern));
+ EXPECT_NE(nullptr, addend0);
+ EXPECT_NE(nullptr, addend1);
+ EXPECT_EQ(nullptr, addend2);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 178a78ede0..c522e7ae23 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
if (platform->id() == se::host::kHostPlatformId) {
// On host "devices", StreamExecutor exports a device for each hardware
// thread. Because we parallelize a single computation across threads, it
- // doesn't make sense to expose these as separate devices, so fix the number
- // of devices to one.
- device_count = 1;
+ // doesn't make sense to expose these as separate devices, so by default we
+ // fix the number of devices to one. However we do let the user override
+ // this behavior to help run tests on the host that run models in parallel
+ // across multiple devices.
+ device_count = legacy_flags::GetDebugOptionsFromFlags()
+ .xla_force_host_platform_device_count();
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
VLOG(1) << "Initializing devices";
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 256b231e3a..4bb22428f3 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -29,7 +29,7 @@ namespace xla {
// HLO pass which inserts reduce-precision instructions into the HLO graph, for
// purposes of experimenting with the effects of reduced-precision storage of
// intermediate values.
-class ReducePrecisionInsertion : public HloPassInterface {
+class ReducePrecisionInsertion : public HloModulePass {
using InstructionFilterFunction = std::function<bool(const HloInstruction*)>;
public:
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
index 1e86a0823a..a3db439e34 100644
--- a/tensorflow/compiler/xla/service/reshape_mover.h
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -24,7 +24,7 @@ namespace xla {
// This now only moves them outputward across elementwise ops all whose operands
// are equivalent Reshapes or Transposes, but in future could potentially move
// them inputward also.
-class ReshapeMover : public HloPassInterface {
+class ReshapeMover : public HloModulePass {
public:
absl::string_view name() const override { return "reshape-mover"; }
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
index 2f4b2667c4..de7aee262e 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.cc
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -155,6 +155,53 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
}
+static StatusOr<HloInstruction*> CheckIndexValidity(
+ HloComputation* computation, HloInstruction* index,
+ absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes,
+ HloModule* module) {
+ DCHECK_NE(nullptr, module);
+ DCHECK_EQ(operand_dims.size(), window_sizes.size());
+
+ // Valid range for the index: [0, operand_dims - window_sizes]
+
+ // Check if the index has any negative values.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * zero_index,
+ BroadcastZeros(computation, index->shape().element_type(),
+ AsInt64Slice(index->shape().dimensions())));
+ TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check,
+ MakeBinaryHlo(HloOpcode::kLe, zero_index, index));
+
+ // Check if the index is OOB w.r.t. the operand dimensions and window sizes.
+ std::vector<int64> max_valid_index(operand_dims.size());
+ for (int i = 0; i < operand_dims.size(); ++i) {
+ max_valid_index[i] = operand_dims[i] - window_sizes[i];
+ }
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * max_valid_index_constant,
+ MakeR1ConstantHlo<int64>(computation, index->shape().element_type(),
+ max_valid_index));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * oob_index_check,
+ MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index));
+
+ // Combine the results of the two checks above.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index,
+ MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check));
+
+ // Reduce the index validity check vector into a scalar predicate.
+ auto reduction_init = computation->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * valid_index_reduced,
+ MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module));
+
+ // Return a broadcasted value of the scalar predicate to the same size as the
+ // window.
+ return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes);
+}
+
// Body of the while loop that performs the scatter operation using other HLOs.
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
HloInstruction* scatter, HloInstruction* induction_var,
@@ -222,7 +269,16 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
InsertDegenerateDims(update_slice_for_scatter,
AsInt64Slice(dim_numbers.inserted_window_dims())));
- // Extact the slice to update from `operand` tensor.
+ // Note that the following transformation assumes that both DynamicSlice and
+ // DynamicUpdateSlice follow the same semantics for OOB indices. For example,
+ // if there are negative indices and DynamicSlice uses "clamping" semantics,
+ // then the extracted data will be "shifted". Since DynamicUpdateSlice also
+ // follows the same "clamping" semantics, writing the update will also be
+ // "shifted" by exactly the same amount. So, this transformation is correct as
+ // long as the semantics of handling OOB indices remain the same in
+ // DynamicSlice and DynamicUpdateSlice.
+
+ // Extract the slice to update from `operand` tensor.
const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
TF_ASSIGN_OR_RETURN(
HloInstruction * operand_slice_to_update,
@@ -237,10 +293,24 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
scatter->to_apply()));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * is_index_valid,
+ CheckIndexValidity(
+ operand->parent(), scatter_slice_start,
+ AsInt64Slice(operand->shape().dimensions()),
+ AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()),
+ scatter->GetModule()));
+
+ // Select the updated operand only if the index is valid. If not, select the
+ // original value.
+ TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply,
+ MakeSelectHlo(is_index_valid, updated_operand_slice,
+ operand_slice_to_update));
+
// Write the updated value of the slice into `operand` tensor.
- TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
- MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
- scatter_slice_start));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * updated_operand,
+ MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start));
return StatusOr<std::vector<HloInstruction*>>{
{updated_operand, scatter_indices, updates}};
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
index 14f062c89c..559a85dccf 100644
--- a/tensorflow/compiler/xla/service/scatter_expander.h
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -20,7 +20,7 @@ limitations under the License.
namespace xla {
-class ScatterExpander : public HloPassInterface {
+class ScatterExpander : public HloModulePass {
public:
absl::string_view name() const override { return "scatter_expander"; }
StatusOr<bool> Run(HloModule* module) override;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 922ebdf0e3..b27a92f2a0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -812,7 +812,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
HloModule::CreateFromProto(module_proto, *module_config));
- TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
TF_ASSIGN_OR_RETURN(
module, backend->compiler()->RunHloPasses(std::move(module), executor,
@@ -1160,7 +1160,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
return replicas;
}
-Status Service::MaybeDumpHloModule(const HloModule& module) const {
+Status Service::MaybeDumpUnoptimizedHloModule(const HloModule& module) const {
const string xla_dump_unoptimized_hlo_proto_to =
module.config().debug_options().xla_dump_unoptimized_hlo_proto_to();
if (xla_dump_unoptimized_hlo_proto_to.empty()) {
@@ -1168,7 +1168,8 @@ Status Service::MaybeDumpHloModule(const HloModule& module) const {
}
HloProto proto = MakeHloProto(module);
return protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_unoptimized_hlo_proto_to, module.name());
+ proto, xla_dump_unoptimized_hlo_proto_to,
+ StrCat(module.name(), ".unoptimized"));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 44c5248b15..1f62fad4c8 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -271,7 +271,9 @@ class Service : public ServiceInterface {
StatusOr<std::vector<se::StreamExecutor*>> Replicas(
const Backend& backend, const DeviceHandle& device_handle) const;
- Status MaybeDumpHloModule(const HloModule& module) const;
+ // Dumps the (unoptimized) module given if the corresponding DebugOptions
+ // field has been set.
+ Status MaybeDumpUnoptimizedHloModule(const HloModule& module) const;
// Returns the device handle that represents the replicated device for a
// single computation that is not model-parallelized.
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 74bdf2a2e3..7194b2cafd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1665,10 +1665,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (input_features != kernel_input_features * feature_group_count) {
return InvalidArgument(
"Expected LHS feature dimension (value %d) to match RHS "
- "input feature dimension * feature_group_count (value %d); "
+ "input feature dimension * feature_group_count (value %d * %d = %d); "
"got <conv>(%s, %s)\n"
"Dimension numbers: {%s}.",
- input_features, kernel_input_features * feature_group_count,
+ input_features, kernel_input_features, feature_group_count,
+ kernel_input_features * feature_group_count,
ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs),
dnums.DebugString());
}
diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc
deleted file mode 100644
index dd53c7531b..0000000000
--- a/tensorflow/compiler/xla/service/source_map_util.cc
+++ /dev/null
@@ -1,66 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/source_map_util.h"
-
-#include "absl/strings/str_format.h"
-#include "tensorflow/compiler/xla/util.h"
-
-namespace xla {
-namespace source_map_util {
-namespace {
-
-Status InvalidParameterArgumentV(const OpMetadata& op_metadata,
- const char* format, va_list args) {
- string message;
- tensorflow::strings::Appendv(&message, format, args);
- if (!op_metadata.source_file().empty()) {
- absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(),
- op_metadata.source_line());
- }
- return InvalidArgument("%s", message);
-}
-
-} // namespace
-
-Status InvalidParameterArgument(const OpMetadata& op_metadata,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- Status result = InvalidParameterArgumentV(op_metadata, format, args);
- va_end(args);
- return result;
-}
-
-Status InvalidParameterArgument(Executable* executable, int parameter_number,
- const char* format, ...) {
- va_list args;
- va_start(args, format);
- if (executable != nullptr && executable->has_module()) {
- const HloModule& module = executable->module();
- const HloComputation& computation = *module.entry_computation();
- HloInstruction* param = computation.parameter_instruction(parameter_number);
- const OpMetadata& metadata = param->metadata();
- Status result = InvalidParameterArgumentV(metadata, format, args);
- va_end(args);
- return result;
- }
- Status result = InvalidArgumentV(format, args);
- va_end(args);
- return result;
-}
-
-} // namespace source_map_util
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c442..ec09dff924 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
// Re-use an existing stream from the pool.
stream = std::move(streams_.back());
streams_.pop_back();
- VLOG(1) << stream->DebugStreamPointers()
- << " StreamPool reusing existing stream";
+ if (stream->ok()) {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " StreamPool reusing existing stream";
+ } else {
+ VLOG(1) << stream->DebugStreamPointers()
+ << " stream was not ok, StreamPool deleting";
+ stream = nullptr;
+ }
}
}
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37b0d..92f47579d3 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
EXPECT_EQ(stream2_ptr, stream3_ptr);
}
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+ std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+ StreamPool pool;
+
+ // Borrow a stream.
+ StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream1->ok());
+
+ // Return the stream, but hold a handle to it.
+ se::Stream* stream1_ptr = stream1.get();
+ stream1 = nullptr;
+
+ // Now stream1 is back in the pool, force an error on the stream. Here we call
+ // a method that requires DNN support, which we know the Host platform doesn't
+ // support.
+ stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+ EXPECT_FALSE(stream1_ptr->ok());
+
+ // Borrow stream2.
+ StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+ EXPECT_TRUE(stream2->ok());
+
+ // The underlying streams should be different. They would have been
+ // the same, but since we forced an error on stream1, it cannot be
+ // put back into the pool. Sadly we can't just check:
+ // EXPECT_NE(stream1_ptr, stream2_ptr);
+ //
+ // The above should hold logically, but it may fail if the new
+ // stream instance allocated for stream2 happens to reside in the
+ // same memory address as stream1, which has been deleted.
+ //
+ // The check that stream2->ok() serves as a good-enough check.
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
index 3e5aa2db60..f95f982eb8 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.h
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -23,7 +23,7 @@ namespace xla {
// HLO pass that folds transpose operators into Dot operators, where the Dot
// operator is implemented by a GEMM kernel that can transpose its inputs.
-class TransposeFolding : public HloPassInterface {
+class TransposeFolding : public HloModulePass {
public:
using OperandIndices = std::vector<int64>;
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h
index 8c91d6e69d..e126a53023 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.h
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.h
@@ -25,7 +25,7 @@ namespace xla {
// A pass which simplifies patterns of Tuple and GetTupleElement instructions in
// the module.
-class TupleSimplifier : public HloPassInterface {
+class TupleSimplifier : public HloModulePass {
public:
TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {}
explicit TupleSimplifier(bool exclude_entry_computation);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
index 2dba7d7f75..577bad6c70 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h
@@ -50,7 +50,7 @@ namespace xla {
// conditions as well.
//
// TODO(b/79121449): We should also sink broadcasts of constants.
-class WhileLoopConstantSinking : public HloPassInterface {
+class WhileLoopConstantSinking : public HloModulePass {
public:
~WhileLoopConstantSinking() override = default;
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
index 2cdf20ce80..3031899f71 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -25,7 +25,7 @@ namespace xla {
// HLO pass that rewrites while loops to hoist loop invariant instructions in
// the while body into the computation that contains the while instruction.
-class WhileLoopInvariantCodeMotion : public HloPassInterface {
+class WhileLoopInvariantCodeMotion : public HloModulePass {
public:
// If `hoist_constants` is true then constants are always hoisted out of while
// loop bodies. Otherwise they are only hoisted out if they enable other
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 6a7bfe3f12..9a74f22395 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -252,7 +252,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Create the new while condition, body, and init value.
std::unique_ptr<HloComputation> new_while_cond =
while_cond->CloneWithReplacements(
- make_while_computation_replacements(while_cond));
+ make_while_computation_replacements(while_cond), /*extras=*/{});
std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
while_body_replacements = make_while_computation_replacements(while_body);
@@ -265,7 +265,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
while_body_replacements.emplace(
while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems));
std::unique_ptr<HloComputation> new_while_body =
- while_body->CloneWithReplacements(std::move(while_body_replacements));
+ while_body->CloneWithReplacements(std::move(while_body_replacements),
+ /*extras=*/{});
// Add a new while_init instruction that repackages the old while_init
// instruction's elements. We rely on the AlgebraicSimplifier and DCE to
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h
index 78024f14dc..0bc5a0107b 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.h
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h
@@ -30,7 +30,7 @@ namespace xla {
// - Elements of a while loop's tuple that the loop doesn't use are removed
// from the tuple.
//
-class WhileLoopSimplifier : public HloPassInterface {
+class WhileLoopSimplifier : public HloModulePass {
public:
~WhileLoopSimplifier() override {}
absl::string_view name() const override { return "simplify-while-loops"; }
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
index a7f0e207eb..87294120d5 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h
@@ -21,7 +21,7 @@ limitations under the License.
// HLO pass that replaces zero sized Hlos with a zero sized constant literal.
namespace xla {
-class ZeroSizedHloElimination : public HloPassInterface {
+class ZeroSizedHloElimination : public HloModulePass {
public:
StatusOr<bool> Run(HloModule* module) override;
absl::string_view name() const override {
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9772c06bce..020c167ee9 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
- CHECK_EQ(shape.dimensions_size(), Rank(shape));
+ DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
+ DCHECK_EQ(shape.dimensions_size(), Rank(shape));
+ if (shape.dimensions().size() == 1) {
+ return shape.dimensions()[0];
+ }
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
std::multiplies<int64>());
@@ -441,6 +444,19 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return count;
}
+/* static */ bool ShapeUtil::HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type) {
+ if (shape.element_type() == primitive_type) {
+ return true;
+ }
+ for (const Shape& element_shape : shape.tuple_shapes()) {
+ if (HasPrimitiveType(element_shape, primitive_type)) {
+ return true;
+ }
+ }
+ return false;
+}
+
/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 8234fcdd3f..d8bb27beae 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <initializer_list>
#include <string>
+#include "absl/base/macros.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -180,6 +181,10 @@ class ShapeUtil {
// As ElementsIn(), but recurses through tuples.
static int64 ElementsInRecursive(const Shape& shape);
+ // Returns true if shape has the primitive type, recurses through tuples.
+ static bool HasPrimitiveType(const Shape& shape,
+ PrimitiveType primitive_type);
+
// Returns true if 'shape' is an array with zero elements.
static bool IsZeroElementArray(const Shape& shape);
@@ -475,8 +480,7 @@ class ShapeUtil {
// Shorthand for testing whether a shape is of a given element type and
// sequence of dimensions.
- //
- // DEPRECATED: Use Equal() instead.
+ ABSL_DEPRECATED("Use Equal() instead.")
static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
std::initializer_list<int64> dimensions);
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 6ca4085aaf..c622ecdca1 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -445,6 +445,22 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
+TEST(ShapeUtilTest, HasPrimitiveType) {
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {}), S16));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeShape(S32, {0}), S32));
+ EXPECT_FALSE(ShapeUtil::HasPrimitiveType(ShapeUtil::MakeTupleShape({}), S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}),
+ S32));
+ EXPECT_TRUE(ShapeUtil::HasPrimitiveType(
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S16, {})})}),
+ S16));
+}
+
TEST(ShapeUtilTest, IsZeroElementArray) {
EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index d0bda45cf8..f474ecb18c 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites"
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
# Generate test_suites for all backends, named "${backend}_tests".
generate_backend_suites()
@@ -150,11 +154,31 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/core:lib",
- "//tensorflow/core:test",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
],
)
+tf_cc_test(
+ name = "hlo_verified_test_base_test",
+ srcs = ["hlo_verified_test_base_test.cc"],
+ deps = [
+ ":hlo_test_base",
+ ":hlo_verified_test_base",
+ ":test_macros_cpu",
+ ":test_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
tf_cc_binary(
name = "local_client_aot_test_helper",
srcs = ["local_client_aot_test_helper.cc"],
@@ -647,6 +671,7 @@ xla_test(
],
shard_count = 48,
tags = [
+ "broken",
"manual",
"notap",
],
@@ -1796,7 +1821,7 @@ xla_test(
tf_cc_test(
name = "llvm_compiler_test",
srcs = ["llvm_compiler_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:test_helpers",
@@ -2095,7 +2120,7 @@ tf_cc_test(
name = "sample_file_test",
srcs = ["sample_file_test.cc"],
data = ["isolated_convolution.hlo"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":hlo_test_base",
"//tensorflow/compiler/xla:test",
@@ -2143,3 +2168,21 @@ xla_test(
"//tensorflow/core:lib",
],
)
+
+tf_cc_test(
+ name = "multiple_devices_on_host_test",
+ srcs = ["multiple_devices_on_host_test.cc"],
+ args = ["--xla_force_host_platform_device_count=4"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "@com_google_absl//absl/synchronization",
+ ],
+)
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 53f2c3bfbf..05d4d04034 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -3,256 +3,266 @@
load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
- """Removes "gpu" from a backend list if CUDA is not enabled.
-
- This allows us to simply hardcode lists including "gpu" here and in the
- BUILD file, without causing failures when CUDA isn't enabled.'
-
- Args:
- backends: A list of backends to filter.
-
- Returns:
- The filtered list of backends.
- """
- if cuda_is_configured():
- return backends
- else:
- return [backend for backend in backends if backend != "gpu"]
-
-
-def xla_test(name,
- srcs,
- deps,
- xla_test_library_deps=[],
- backends=[],
- blacklisted_backends=[],
- args=[],
- tags=[],
- copts=[],
- data=[],
- backend_tags={},
- backend_args={},
- **kwargs):
- """Generates cc_test targets for the given XLA backends.
-
- This rule generates a cc_test target for one or more XLA backends and also a
- platform-agnostic cc_library rule. The arguments are identical to cc_test with
- two additions: 'backends' and 'backend_args'. 'backends' specifies the
- backends to generate tests for ("cpu", "gpu"), and
- 'backend_args'/'backend_tags' specifies backend-specific args parameters to
- use when generating the cc_test.
-
- The name of the cc_tests are the provided name argument with the backend name
- appended, and the cc_library target name is the provided name argument with
- "_lib" appended. For example, if name parameter is "foo_test", then the cpu
- test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
-
- The cc_library target can be used to link with other plugins outside of
- xla_test.
-
- The build rule also defines a test suite ${name} which includes the tests for
- each of the supported backends.
-
- Each generated cc_test target has a tag indicating which backend the test is
- for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
- tags can be used to gather tests for a particular backend into a test_suite.
-
- Examples:
-
- # Generates the targets: foo_test_cpu and foo_test_gpu.
- xla_test(
- name = "foo_test",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
+ """Removes "gpu" from a backend list if CUDA is not enabled.
- # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
- # includes the additional arg "--special_cpu_flag".
- xla_test(
- name = "bar_test",
- srcs = ["bar_test.cc"],
- backends = ["cpu", "gpu"],
- backend_args = {"cpu": ["--special_cpu_flag"]}
- deps = [...],
- )
+ This allows us to simply hardcode lists including "gpu" here and in the
+ BUILD file, without causing failures when CUDA isn't enabled.'
- The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
- to the value 1 where ${BACKEND} is the uppercase name of the backend.
-
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- deps: Dependencies of the target.
- xla_test_library_deps: If set, the generated test targets will depend on the
- respective cc_libraries generated by the xla_test_library rule.
- backends: A list of backends to generate tests for. Supported values: "cpu",
- "gpu". If this list is empty, the test will be generated for all supported
- backends.
- blacklisted_backends: A list of backends to NOT generate tests for.
- args: Test arguments for the target.
- tags: Tags for the target.
- copts: Additional copts to pass to the build.
- data: Additional data to pass to the build.
- backend_tags: A dict mapping backend name to list of additional tags to
- use for that target.
- backend_args: A dict mapping backend name to list of additional args to
- use for that target.
- **kwargs: Additional keyword arguments to pass to native.cc_test.
- """
- test_names = []
- if not backends:
- backends = all_backends
-
- backends = [backend for backend in backends
- if backend not in blacklisted_backends]
-
- native.cc_library(
- name="%s_lib" % name,
- srcs=srcs,
- copts=copts,
- testonly=True,
- deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
- )
-
- for backend in filter_backends(backends):
- test_name = "%s_%s" % (name, backend)
- this_backend_tags = ["xla_%s" % backend]
- this_backend_copts = []
- this_backend_args = backend_args.get(backend, [])
- this_backend_data = []
- if backend == "cpu":
- backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- elif backend == "gpu":
- backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
- this_backend_tags += ["requires-gpu-sm35"]
- elif backend in plugins:
- backend_deps = []
- backend_deps += plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
- this_backend_tags += plugins[backend]["tags"]
- this_backend_args += plugins[backend]["args"]
- this_backend_data += plugins[backend]["data"]
- else:
- fail("Unknown backend %s" % backend)
-
- if xla_test_library_deps:
- for lib_dep in xla_test_library_deps:
- backend_deps += ["%s_%s" % (lib_dep, backend)]
-
- tf_cc_test(
- name=test_name,
- srcs=srcs,
- tags=tags + backend_tags.get(backend, []) + this_backend_tags,
- extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
- this_backend_copts,
- args=args + this_backend_args,
- deps=deps + backend_deps,
- data=data + this_backend_data,
- **kwargs)
-
- test_names.append(test_name)
-
- native.test_suite(name=name, tests=test_names)
-
-def xla_test_library(name,
- srcs,
- hdrs=[],
- deps=[],
- backends=[]):
- """Generates cc_library targets for the given XLA backends.
-
- This rule forces the sources to be compiled for each backend so that the
- backend specific macros could expand correctly. It's useful when test targets
- in different directories referring to the same sources but test with different
- arguments.
-
- Examples:
-
- # Generates the targets: foo_test_library_cpu and foo_test_gpu.
- xla_test_library(
- name = "foo_test_library",
- srcs = ["foo_test.cc"],
- backends = ["cpu", "gpu"],
- deps = [...],
- )
- # Then use the xla_test rule to generate test targets:
- xla_test(
- name = "foo_test",
- srcs = [],
- backends = ["cpu", "gpu"],
- deps = [...],
- xla_test_library_deps = [":foo_test_library"],
- )
+ Args:
+ backends: A list of backends to filter.
- Args:
- name: Name of the target.
- srcs: Sources for the target.
- hdrs: Headers for the target.
- deps: Dependencies of the target.
- backends: A list of backends to generate libraries for.
- Supported values: "cpu", "gpu". If this list is empty, the
- library will be generated for all supported backends.
- """
-
- if not backends:
- backends = all_backends
-
- for backend in filter_backends(backends):
- this_backend_copts = []
- if backend in ["cpu", "gpu"]:
- backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
- elif backend in plugins:
- backend_deps = plugins[backend]["deps"]
- this_backend_copts += plugins[backend]["copts"]
+ Returns:
+ The filtered list of backends.
+ """
+ if cuda_is_configured():
+ return backends
else:
- fail("Unknown backend %s" % backend)
+ return [backend for backend in backends if backend != "gpu"]
+
+def xla_test(
+ name,
+ srcs,
+ deps,
+ xla_test_library_deps = [],
+ backends = [],
+ blacklisted_backends = [],
+ args = [],
+ tags = [],
+ copts = [],
+ data = [],
+ backend_tags = {},
+ backend_args = {},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
+
+ This rule generates a cc_test target for one or more XLA backends and also a
+ platform-agnostic cc_library rule. The arguments are identical to cc_test with
+ two additions: 'backends' and 'backend_args'. 'backends' specifies the
+ backends to generate tests for ("cpu", "gpu"), and
+ 'backend_args'/'backend_tags' specifies backend-specific args parameters to
+ use when generating the cc_test.
+
+ The name of the cc_tests are the provided name argument with the backend name
+ appended, and the cc_library target name is the provided name argument with
+ "_lib" appended. For example, if name parameter is "foo_test", then the cpu
+ test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
+
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
+
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
+
+ Each generated cc_test target has a tag indicating which backend the test is
+ for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
+
+ Examples:
+
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+ backend_args = {"cpu": ["--special_cpu_flag"]}
+ deps = [...],
+ )
+
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+ to the value 1 where ${BACKEND} is the uppercase name of the backend.
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ xla_test_library_deps: If set, the generated test targets will depend on the
+ respective cc_libraries generated by the xla_test_library rule.
+ backends: A list of backends to generate tests for. Supported values: "cpu",
+ "gpu". If this list is empty, the test will be generated for all supported
+ backends.
+ blacklisted_backends: A list of backends to NOT generate tests for.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ copts: Additional copts to pass to the build.
+ data: Additional data to pass to the build.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ **kwargs: Additional keyword arguments to pass to native.cc_test.
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends
+
+ backends = [
+ backend
+ for backend in backends
+ if backend not in blacklisted_backends
+ ]
native.cc_library(
- name = "%s_%s" % (name, backend),
+ name = "%s_lib" % name,
srcs = srcs,
+ copts = copts,
testonly = True,
- hdrs = hdrs,
- copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
- + this_backend_copts,
- deps = deps + backend_deps,
+ deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
-
-def generate_backend_suites(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- native.test_suite(name="%s_tests" % backend,
- tags = ["xla_%s" % backend])
-
-
-def generate_backend_test_macros(backends=[]):
- if not backends:
- backends = all_backends
- for backend in filter_backends(backends):
- manifest = ""
- if backend in plugins:
- manifest = plugins[backend]["disabled_manifest"]
-
- native.cc_library(
- name="test_macros_%s" % backend,
- testonly = True,
- srcs = ["test_macros.cc"],
- hdrs = ["test_macros.h"],
- copts = [
- "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
- "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
- ],
- deps = [
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- "//tensorflow/core:test",
- ])
+ for backend in filter_backends(backends):
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ this_backend_data = []
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += tf_cuda_tests_tags()
+ elif backend in plugins:
+ backend_deps = []
+ backend_deps += plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ this_backend_tags += plugins[backend]["tags"]
+ this_backend_args += plugins[backend]["args"]
+ this_backend_data += plugins[backend]["data"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ if xla_test_library_deps:
+ for lib_dep in xla_test_library_deps:
+ backend_deps += ["%s_%s" % (lib_dep, backend)]
+
+ tf_cc_test(
+ name = test_name,
+ srcs = srcs,
+ tags = tags + backend_tags.get(backend, []) + this_backend_tags,
+ extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args = args + this_backend_args,
+ deps = deps + backend_deps,
+ data = data + this_backend_data,
+ **kwargs
+ )
+
+ test_names.append(test_name)
+
+ native.test_suite(name = name, tests = test_names)
+
+def xla_test_library(
+ name,
+ srcs,
+ hdrs = [],
+ deps = [],
+ backends = []):
+ """Generates cc_library targets for the given XLA backends.
+
+ This rule forces the sources to be compiled for each backend so that the
+ backend specific macros could expand correctly. It's useful when test targets
+ in different directories referring to the same sources but test with different
+ arguments.
+
+ Examples:
+
+ # Generates the targets: foo_test_library_cpu and foo_test_gpu.
+ xla_test_library(
+ name = "foo_test_library",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+ # Then use the xla_test rule to generate test targets:
+ xla_test(
+ name = "foo_test",
+ srcs = [],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ xla_test_library_deps = [":foo_test_library"],
+ )
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ hdrs: Headers for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate libraries for.
+ Supported values: "cpu", "gpu". If this list is empty, the
+ library will be generated for all supported backends.
+ """
+
+ if not backends:
+ backends = all_backends
+
+ for backend in filter_backends(backends):
+ this_backend_copts = []
+ if backend in ["cpu", "gpu"]:
+ backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
+ elif backend in plugins:
+ backend_deps = plugins[backend]["deps"]
+ this_backend_copts += plugins[backend]["copts"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_library(
+ name = "%s_%s" % (name, backend),
+ srcs = srcs,
+ testonly = True,
+ hdrs = hdrs,
+ copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ deps = deps + backend_deps,
+ )
+
+def generate_backend_suites(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ native.test_suite(
+ name = "%s_tests" % backend,
+ tags = ["xla_%s" % backend, "-broken", "manual"],
+ )
+
+def generate_backend_test_macros(backends = []):
+ if not backends:
+ backends = all_backends
+ for backend in filter_backends(backends):
+ manifest = ""
+ if backend in plugins:
+ manifest = plugins[backend]["disabled_manifest"]
+
+ native.cc_library(
+ name = "test_macros_%s" % backend,
+ testonly = True,
+ srcs = ["test_macros.cc"],
+ hdrs = ["test_macros.h"],
+ copts = [
+ "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
+ "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:test",
+ ],
+ )
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 738f2600d4..51b50d456e 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -45,22 +45,22 @@ class ExhaustiveF32ElementwiseOpTest
i < known_incorrect_range.second) {
// If the operation is known to be buggy on a specific input clamp that
// input to 0 under the assumption that the op is at least correct on 0.
- input_literal->Set({i - begin}, 0.0f);
+ input_literal.Set({i - begin}, 0.0f);
} else {
- input_literal->Set({i - begin}, tensorflow::bit_cast<float, int>(i));
+ input_literal.Set({i - begin}, tensorflow::bit_cast<float, int>(i));
}
}
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
- client_->TransferToServer(*input_literal));
+ client_->TransferToServer(input_literal));
- auto input = Parameter(&builder, 0, input_literal->shape(), "input");
+ auto input = Parameter(&builder, 0, input_literal.shape(), "input");
enqueue_op(&builder, input);
std::vector<float> expected_result;
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
- expected_result.push_back(evaluate_op(input_literal->Get<float>({i})));
+ expected_result.push_back(evaluate_op(input_literal.Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index 8f86c528d0..8bd0a729b7 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -21,64 +21,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/test.h"
namespace xla {
-HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
- bool allow_mixed_precision)
- : HloTestBase(
- /*verifier_layout_sensitive=*/layout_sensitive,
- /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {}
-
-HloVerifiedTestBase::~HloVerifiedTestBase() {
- // We can't call the ASSERT or EXPECT test macros in destructors, so we
- // perform HLO verification in TearDown, and use the CHECK here to ensure
- // users don't accidentally override the verification.
- CHECK(tear_down_called_)
- << "TearDown was never called; subclasses of HloVerifiedTestBase that "
- << "override TearDown must call the superclass TearDown.";
-}
-
-void HloVerifiedTestBase::TearDown() {
- EXPECT_FALSE(tear_down_called_)
- << "TearDown called more than once; it should be called exactly once.";
- tear_down_called_ = true;
- if (module_) {
- VerifyModule(module_.get());
+Status VerifiedHloModule::Verify() {
+ if (computation_count() == 0) {
+ // The computation was never built. Nothing to verify.
+ return Status::OK();
}
- for (int i = 0; i < modules_.size(); ++i) {
- VerifyModule(modules_.at(i).get());
- }
- HloTestBase::TearDown();
+ return verifier_.Run(this).status();
}
-void HloVerifiedTestBase::VerifyModule(HloModule* module) {
- xla::StatusOr<bool> mutated = verifier().Run(module);
- if (!mutated.ok()) {
- ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
- } else {
- EXPECT_FALSE(mutated.ValueOrDie())
- << "HloVerifier should never mutate the HloModule";
+void VerifiedHloModule::VerifyOrAddFailure(const string& message) {
+ Status status = Verify();
+ if (!status.ok()) {
+ ADD_FAILURE() << "HloVerifier failed on module " << name()
+ << (message.empty() ? "" : absl::StrCat(" (", message, ")"))
+ << ": " << status;
}
}
+HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive,
+ bool allow_mixed_precision)
+ : HloTestBase(
+ /*verifier_layout_sensitive=*/layout_sensitive,
+ /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision),
+ verifier_layout_sensitive_(layout_sensitive),
+ allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {}
+
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = HloTestBase::CreateNewModule();
+ module_ = CreateNewVerifiedModule(TestName());
}
return *module_;
}
HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
- modules_.emplace_back(HloTestBase::CreateNewModule());
+ modules_.emplace_back(CreateNewVerifiedModule(name));
return modules_.back().get();
}
void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config));
- VerifyModule(module_.get());
+ module_ = CreateNewVerifiedModule(TestName());
+ TF_CHECK_OK(ParseHloString(hlo_text, module_.get()));
+ module_->VerifyOrAddFailure("after parsing");
}
+
+StatusOr<std::unique_ptr<VerifiedHloModule>>
+HloVerifiedTestBase::ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text, const HloModuleConfig& config) {
+ auto module = CreateNewVerifiedModule(TestName());
+ TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get()));
+ TF_RETURN_IF_ERROR(module->Verify());
+ return std::move(module);
+}
+
+std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule(
+ const string& name) {
+ return absl::make_unique<VerifiedHloModule>(
+ name, GetModuleConfigForTest(), verifier_layout_sensitive_,
+ allow_mixed_precision_in_hlo_verifier_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index 8fbc4fa753..388a99bb36 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -20,53 +20,84 @@ limitations under the License.
#include <memory>
#include <utility>
+#include "absl/base/macros.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
-// A base class for HLO tests that stores a default HloModule, and automatically
-// performs verification on that module on tear-down.
+// An HLO module derived class which verifies itself on destruction. This class
+// is intended to be used in unit tests. Any verification errors are raised via
+// ADD_FAILURE.
+class VerifiedHloModule : public HloModule {
+ public:
+ VerifiedHloModule(const string& name, const HloModuleConfig& config,
+ bool verifier_layout_sensitive,
+ bool allow_mixed_precision_in_hlo_verifier)
+ : HloModule(name, config),
+ verifier_(verifier_layout_sensitive,
+ allow_mixed_precision_in_hlo_verifier) {}
+
+ ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); }
+
+ // Verifies the module using HloVerifier and returns the status.
+ Status Verify();
+
+ // Verifies the module and flags any error with ADD_FAILURE. 'message' is
+ // included in the failure message.
+ void VerifyOrAddFailure(const string& message);
+
+ private:
+ HloVerifier verifier_;
+};
+
+// A base class for HLO tests that stores a default VerifiedHloModule.
class HloVerifiedTestBase : public HloTestBase {
protected:
- explicit HloVerifiedTestBase(bool layout_sensitive = false,
- bool allow_mixed_precision = false);
- ~HloVerifiedTestBase() override;
+ HloVerifiedTestBase(bool layout_sensitive = false,
+ bool allow_mixed_precision = false);
// Constructs a default shape verifier.
std::unique_ptr<ShapeVerifier> MakeShapeVerifier();
- // Performs verification on the default HloModule returned by module().
- // Automatically called by the testing framework for each test.
- //
- // REQUIRED: subclasses that override TearDown() must call this explicitly.
- void TearDown() override;
-
// Returns the default HloModule, lazily creating it if necessary via
// HloTestBase::CreateNewModule().
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule& module();
+
+ ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.")
void ParseAndVerifyModule(absl::string_view hlo_text,
const HloModuleConfig& config = HloModuleConfig());
+ // Parses the given string and returns module as a VerifiedHloModule.
+ StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
+ absl::string_view hlo_text,
+ const HloModuleConfig& config = HloModuleConfig());
+
// Creates a new module for a test, and stores it in modules_ so it can be
// verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
// creation of unverified modules.
+ ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.")
HloModule* CreateNewModule(const string& name = TestName());
- private:
- void VerifyModule(HloModule* module);
+ // Creates and returns a verified HLO module with the given name.
+ std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
+ const string& name = TestName());
+ private:
// It is confusing to store modules created by module() and CreateNewModule()
// in different fields, but it allows us to migrate tests to
// HloVerifiedTestBase more easily, so it's a win because we can verify more
// modules. See b/80488902.
//
// Lazily populated. Access via module().
- std::unique_ptr<HloModule> module_;
+ std::unique_ptr<VerifiedHloModule> module_;
+
// Populated by calls to CreateNewModule.
- std::vector<std::unique_ptr<HloModule>> modules_;
+ std::vector<std::unique_ptr<VerifiedHloModule>> modules_;
- bool tear_down_called_ = false;
+ bool verifier_layout_sensitive_;
+ bool allow_mixed_precision_in_hlo_verifier_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
new file mode 100644
index 0000000000..5c0263e811
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// This class includes unit tests which are expected to fail because invalid HLO
+// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to
+// include the necessary gunit parts to test this test machinery (needs the
+// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the
+// disabled tests enabled and failures can be manually compared against
+// expectations.
+class HloVerifiedTestBaseTest : public HloVerifiedTestBase {};
+
+XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) {
+ // Test shouldn't fail if no module is created at all.
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) {
+ // Use module() to lazily create an empty module, build it up, and verify no
+ // failures.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) {
+ // Use module() to lazily create an empty module and build up an invalid
+ // module.
+ HloModule& hlo_module = module();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ hlo_module.AddEntryComputation(builder.Build());
+
+ *hlo_module.entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) {
+ // Call CreateNewModule and build up a valid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) {
+ // Call CreateNewModule and build up a invalid module.
+ HloModule* module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input));
+ module->AddEntryComputation(builder.Build());
+
+ *module->entry_computation()->root_instruction()->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {1, 2, 3});
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndVerifyModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ ParseAndVerifyModule(hlo_string);
+ EXPECT_EQ(module().entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(hlo_string));
+ EXPECT_EQ(module->entry_computation()->instruction_count(), 3);
+}
+
+XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleGood
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x,y)
+}
+
+RANDOM GARBAGE
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+// This test is expected to fail. See test class comment.
+XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) {
+ const char* const hlo_string = R"(
+HloModule ParseAndReturnVerifiedModuleBad
+
+ENTRY entry {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[1234] add(x,y)
+}
+)";
+
+ ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
new file mode 100644
index 0000000000..c530591c6e
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+StatusOr<XlaComputation> BuildComputation() {
+ XlaBuilder b("computation");
+ Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32);
+ return b.Build(
+ OutfeedWithToken(GetTupleElement(infeed, 0) +
+ ConstantLiteral(&b, LiteralUtil::CreateR0<int32>(1)),
+ GetTupleElement(infeed, 1), scalar_s32, ""));
+}
+
+void CompileAndExecute(
+ LocalExecutable* executable, int device_ordinal, LocalClient* client,
+ absl::Mutex* results_mutex,
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>>* results) {
+ xla::ExecutableRunOptions execute_options;
+ execute_options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ execute_options.set_device_ordinal(device_ordinal);
+ execute_options.set_allocator(
+ xla::ClientLibrary::GetXlaService(client->platform())
+ ->backend()
+ .memory_allocator());
+ StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options);
+ {
+ absl::MutexLock lock(results_mutex);
+ results->emplace_back(device_ordinal, std::move(result));
+ }
+}
+
+void TestWithDeviceCount(const int device_count) {
+ // Run `device_count` copies of the XLA program built by BuildComputation.
+ TF_ASSERT_OK_AND_ASSIGN(
+ se::Platform* const platform,
+ perftools::gputools::MultiPlatformManager::PlatformWithName("Host"));
+ xla::LocalClientOptions client_options;
+ client_options.set_platform(platform);
+ TF_ASSERT_OK_AND_ASSIGN(
+ LocalClient* const client,
+ xla::ClientLibrary::GetOrCreateLocalClient(client_options));
+
+ TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation());
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<LocalExecutable> executable,
+ client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{}));
+ std::vector<tensorflow::Thread*> threads;
+ absl::Mutex results_mutex;
+ std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results;
+ tensorflow::Env* env = tensorflow::Env::Default();
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ tensorflow::Thread* t = env->StartThread(
+ tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal),
+ [&executable, device_ordinal, client, &results_mutex, &results] {
+ CompileAndExecute(executable.get(), device_ordinal, client,
+ &results_mutex, &results);
+ });
+ threads.push_back(t);
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(client->TransferToInfeedLocal(
+ LiteralUtil::CreateR0<int32>(device_ordinal * 100), device_ordinal));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
+ client->TransferFromOutfeedLocal(
+ ShapeUtil::MakeShape(S32, {}), device_ordinal));
+ EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ delete threads[device_ordinal];
+ }
+
+ for (int device_ordinal = 0; device_ordinal < device_count;
+ device_ordinal++) {
+ TF_ASSERT_OK(results[device_ordinal].second.status());
+ }
+}
+
+// NB! This test requires --xla_force_host_platform_device_count=4
+
+TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); }
+
+TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); }
+
+TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); }
+
+TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); }
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index d5de9650f1..c25ccafaf8 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -588,7 +588,7 @@ string R4ReduceWindowTestDataToString(
// Test names are not allowed to contain the '-' character.
std::replace(str.begin(), str.end(), '-', 'n');
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -980,7 +980,7 @@ string R3ReduceWindowTestDataToString(
param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1121,7 +1121,7 @@ string R2ReduceWindowTestDataToString(
param.layout[1], //
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
@@ -1303,11 +1303,19 @@ struct R1ReduceWindowTestData {
/*pad_high=*/{0},
/*reducer=*/Reducer::kAdd},
+ // The pattern generated by inclusive scan (cumsum/cumprod).
{/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
/*strides=*/{1},
/*pad_low=*/{4095},
/*pad_high=*/{0},
/*reducer=*/Reducer::kMax},
+
+ // The pattern generated by exclusive scan (cumsum/cumprod).
+ {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
+ /*strides=*/{1},
+ /*pad_low=*/{4096},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kMax},
};
string R1ReduceWindowTestDataToString(
@@ -1322,7 +1330,7 @@ string R1ReduceWindowTestDataToString(
"__pad_high_", absl::StrJoin(param.pad_high, "x"),
"__reducer_", param.reducer == kAdd ? "add" : "max");
if (::testing::get<1>(data.param)) {
- str = absl::StrCat(str, "_bfloat16");
+ absl::StrAppend(&str, "_bfloat16");
}
return str;
}
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index d20dba028a..b21dd56045 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -507,6 +507,36 @@ ENTRY main {
RunTest(hlo_text, &operand, &scatter_indices, &updates);
}
+XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[1,2] parameter(1)
+ updates = s32[1,2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ Literal operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+ Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+ RunTest(hlo_text, &operand, &scatter_indices, &updates);
+}
+
XLA_TEST_F(ScatterTest, OneScalarIndex) {
const char* hlo_text = R"(
HloModule OneScalarIndex
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index a40c2d7de6..2cc33ab096 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -412,6 +412,7 @@ INSTANTIATE_TEST_CASE_P(
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{7, 11}}, {{0, 1}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{1, 0}}}, //
R2Spec{511, 513, {{129, 300}}, {{400, 500}}, {{11, 7}}, {{0, 1}}}, //
+ R2Spec{8672, 512, {{8, 0}}, {{8672, 512}}, {{542, 1}}, {{1, 0}}}, //
R2Spec{
511, 513, {{129, 300}}, {{400, 500}}, {{101, 129}}, {{1, 0}}}, //
R2Spec{
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 7abd8651d5..8b1b9e1519 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -763,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) {
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
-// Test while nodes that share the while body computation.
-// TODO(b/37245345): Fails on GPU backend.
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) {
+TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) {
std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
ShapeUtil::MakeShape(F32, {10})};
Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b53f89d63b..60d25a6407 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -200,6 +200,15 @@ message DebugOptions {
// among different algorithms.
bool xla_gpu_crash_on_verification_failures = 101;
+ // Force the host platform to pretend that there are these many host
+ // "devices". All these devices are backed by the same threadpool. Defaults
+ // to 1.
+ //
+ // Setting this to anything other than 1 can increase overhead from context
+ // switching but we let the user override this behavior to help run tests on
+ // the host that run models in parallel across multiple devices.
+ int32 xla_force_host_platform_device_count = 102;
+
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index dd329f1181..73b3589dbf 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -351,6 +351,7 @@ message DeviceAssignmentProto {
message LiteralProto {
Shape shape = 1;
repeated bool preds = 2;
+ bytes s8s = 15;
bytes u8s = 3;
repeated int32 s32s = 4;
repeated int64 s64s = 5;
@@ -364,7 +365,7 @@ message LiteralProto {
bytes f16s = 11;
bytes bf16s = 13;
repeated int64 sparse_indices = 14;
- // Next = 15
+ // Next = 16
}
message WindowDimension {
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed95f..b6dcfc4eb9 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@ package(
)
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
cc_library(
name = "raw_api_test_lib",
@@ -57,7 +61,7 @@ tf_cuda_cc_test(
size = "medium",
srcs = [],
args = ["--xla_test_device=XLA_GPU"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":raw_api_test_lib",
"//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 798f499870..ae5ca32bcf 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -60,7 +60,6 @@ py_library(
"//tensorflow/contrib/learn",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/libsvm",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lite/python:lite",
@@ -127,11 +126,16 @@ py_library(
}) + if_not_windows_cuda([
"//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
]) + if_not_windows([
- "//tensorflow/contrib/bigtable", # depends on bigtable
- "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
- ]),
+ ]) + select({
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "//tensorflow/contrib/bigtable",
+ "//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/tensorrt:init_py",
+ "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ ],
+ }),
)
cc_library(
@@ -166,7 +170,9 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_kernels",
],
"//conditions:default": [],
- }),
+ }) + if_not_windows([
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
+ ]),
)
cc_library(
@@ -203,5 +209,7 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
],
"//conditions:default": [],
- }),
+ }) + if_not_windows([
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
+ ]),
)
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9478e42b46..e71b0e0ae3 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
-from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
index b3f5d92259..9a8f62b986 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py
@@ -149,7 +149,7 @@ class AllReduceTest(test_util.TensorFlowTestCase):
num_devices = num_workers * num_gpus
dev_list = ["/replica:0/task:0/device:CPU:0"
for _ in range(num_devices)]
- with self.test_session():
+ with self.cached_session():
input_tensors = self._buildInitialVars(shape, dev_list)
un_op = lambda x: math_ops.div(
x, constant_op.constant(num_devices, dtype=types_pb2.DT_FLOAT))
diff --git a/tensorflow/contrib/autograph/BUILD b/tensorflow/contrib/autograph/BUILD
index ad700ac4a0..e37ad7a758 100644
--- a/tensorflow/contrib/autograph/BUILD
+++ b/tensorflow/contrib/autograph/BUILD
@@ -21,11 +21,9 @@ py_library(
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
+ # This module is kept for backward compatibility only. To depend on AutoGraph,
+ # use //third_party/tensorflow/python/autograph instead.
deps = [
- "//tensorflow/contrib/autograph/impl",
- "//tensorflow/contrib/autograph/lang",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
- "//tensorflow/python:util",
+ "//tensorflow/python/autograph",
],
)
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index cc54da4daa..8c277b59e8 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -1,143 +1,9 @@
# AutoGraph
-IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+**NOTE: As tensorflow.contrib is being
+[deprecated](https://github.com/tensorflow/community/pull/18), AutoGraph is
+moving into TensorFlow core.
-AutoGraph is a Python to TensorFlow compiler.
-
-With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
-
-For example, this Python function:
-
-```
-def f(x):
- if x < 0:
- x = -x
- return x
-```
-
-would be converted to this:
-
-```
-def graph_mode_f(x):
- with tf.name_scope('f'):
-
- def if_true():
- with tf.name_scope('if_true'):
- x_1, = x,
- x_1 = tf.negative(x_1)
- return x_1,
-
- def if_false():
- with tf.name_scope('if_false'):
- x_1, = x,
- return x_1,
- x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
- return x
-```
-
-so you can use it like an op:
-
-```
-with tf.Graph().as_default():
- x = tf.constant(-1.0)
-
- converted_f = autograph.to_graph(f)
- y = converted_f(x)
-
- with tf.Session() as sess:
- print(sess.run(y))
- # Output: 1
-```
-
-# Getting started
-
-Use AutoGraph in one of the following ways, described below:
-
- 1. Annotations (simpler)
- 2. Functional API (more flexible)
-
-To get started, install the latest nightly TensorFlow build:
-
-```shell
-pip install -U tf-nightly
-```
-
-Then import the `autograph` module from `tf.contrib`:
-
-```
-from tensorflow.contrib import autograph as ag
-```
-
-### Related links
-
-Articles:
-
- * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
-
-Interactive notebooks:
-
- * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
- * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
- * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
- * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
- * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
- * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
- * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
-
-## Using with annotations
-
-Annotating a function or class with `@convert` converts it in place:
-
-```
-@ag.convert()
-def f(x):
- if x < 0:
- x = -x
- return x
-```
-
-... so that it always outputs TensorFlow code:
-
-```
-with tf.Graph().as_default():
- x = tf.constant(-1)
-
- y = f(x)
-
- with tf.Session() as sess:
- print(sess.run(y))
- # Output: 1
-```
-
-## Using the functional API
-
-The functional API allows you to convert an existing function, class or object after it was defined:
-
-```
-converted_f = ag.to_graph(f)
-
-print(converted_f(tf.constant(-1)))
-# Output: Tensor
-
-print(f(-1))
-# Output: 1
-```
-
-You can use the functional API to inspect the generated code as well:
-
-```
-print(ag.to_code(f))
-# Output: <Python and TensorFlow code>
-```
-
-## Filing bugs and feature requests
-
-### Reporting a bug
-
- - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
- - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
-
-### Requesting a feature
-
-If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
+The new code location is `tensorflow/python/autograph`. Please refer to the
+README.md file in that directory.
+**
diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py
index 26e7a4a4d3..137bc59202 100644
--- a/tensorflow/contrib/autograph/__init__.py
+++ b/tensorflow/contrib/autograph/__init__.py
@@ -12,57 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Autograph compiles Python code into equivalent TensorFlow code.
+"""This is the legacy module for AutoGraph, kept for backward compatibility.
-Equivalent here means that they have the same effect when executed.
+New users should instead use `tensorflow.python.autograph`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(mdan): Bring only the relevant symbols to the top level.
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core.errors import GraphConstructionError
-from tensorflow.contrib.autograph.core.errors import TfRuntimeError
-from tensorflow.contrib.autograph.core.errors import improved_errors
-from tensorflow.contrib.autograph.impl.api import RunMode
-from tensorflow.contrib.autograph.impl.api import convert
-from tensorflow.contrib.autograph.impl.api import converted_call
-from tensorflow.contrib.autograph.impl.api import do_not_convert
-from tensorflow.contrib.autograph.impl.api import to_code
-from tensorflow.contrib.autograph.impl.api import to_graph
-from tensorflow.contrib.autograph.lang.directives import set_element_type
-from tensorflow.contrib.autograph.lang.directives import set_loop_options
-from tensorflow.contrib.autograph.lang.special_functions import stack
-from tensorflow.contrib.autograph.lang.special_functions import tensor_list
-from tensorflow.contrib.autograph.pyct.transformer import AutographParseError
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- # Main API
- 'RunMode',
- 'convert',
- 'converted_call',
- 'do_not_convert',
- 'to_code',
- 'to_graph',
- # Overloaded operators
- 'operators',
- # Errors
- 'improved_errors',
- 'GraphConstructionError',
- 'TfRuntimeError',
- # Python language "extensions"
- 'set_element_type',
- 'set_loop_options',
- 'stack',
- 'tensor_list',
- # Exceptions
- 'AutographParseError',
- # Utilities: to be removed
- 'utils',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
+from tensorflow.python.autograph import * # pylint:disable=wildcard-import
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
deleted file mode 100644
index 38e0a0a8f0..0000000000
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility module that contains APIs usable in the generated code."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
-from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
-from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
-from tensorflow.contrib.autograph.utils.testing import fake_tf
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index 7846814546..01ee8703a9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase):
def testBasicBatch(self):
"""Tests that a single batched tensor executes together and only once."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase):
def testBatchWithPadding(self):
"""Test that batching with padding up to an allowed batch size works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase):
def testMultipleBatch(self):
"""Tests that multiple batched tensors execute together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, _, _ = batch_ops.batch(
@@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase):
def testIllegalBatchDifferentDim0Sizes(self):
"""Tests illegally feeding tensors with different dim0 sizes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
batched, index, _ = batch_ops.batch(
@@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatch(self):
"""Tests that batch and unbatch work together."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=10,
@@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchV1Decorated(self):
"""Tests that the batch_function_v1 decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
+
@batch_ops.batch_function_v1(1, 10, 100000)
def computation(in_t):
return in_t + 1
@@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# TODO(apassos): Removing this line causes test flakiness! Ideally should
# be investigated.
default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
@@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase):
def testBatchDecoratedWithCapturedInput(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
@@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOp(self):
"""Tests that the batch_function op works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@function.Defun(dtypes.int32)
def computation(in_t):
@@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithCapturedInput(self):
"""Tests that batch_function op works with captured input."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase):
def testBatchFunctionOpWithInputError(self):
"""Tests that batch_function op works with error in the inputs."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
@function.Defun(dtypes.int32, dtypes.int32)
@@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase):
def testBasicUnbatchDecoratedWithReshape(self):
"""Tests that the batch_function decorator works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
@@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
@@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase):
def testUnbatchGrad(self):
"""Tests that batch and unbatch are differentiable."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1])
batched, index, id_t = batch_ops.batch(
[inp], num_batch_threads=1, max_batch_size=2,
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 9e6a146f67..13215ffabf 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -42,7 +42,7 @@ class ExpectationImportanceSampleTest(test.TestCase):
def test_normal_integral_mean_and_var_correctly_estimated(self):
n = int(1e6)
- with self.test_session():
+ with self.cached_session():
mu_p = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
mu_q = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([0.5, 0.5], dtype=dtypes.float64)
@@ -72,7 +72,7 @@ class ExpectationImportanceSampleTest(test.TestCase):
# Test that importance sampling can correctly estimate the probability that
# the product of components in a MultivariateNormal are > 0.
n = 1000
- with self.test_session():
+ with self.cached_session():
p = mvn_diag_lib.MultivariateNormalDiag(
loc=[0.], scale_diag=[1.0, 1.0])
q = mvn_diag_lib.MultivariateNormalDiag(
@@ -99,7 +99,7 @@ class ExpectationImportanceSampleLogspaceTest(test.TestCase):
def test_normal_distribution_second_moment_estimated_correctly(self):
# Test the importance sampled estimate against an analytical result.
n = int(1e6)
- with self.test_session():
+ with self.cached_session():
mu_p = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
mu_q = constant_op.constant([-1.0, 1.0], dtype=dtypes.float64)
sigma_p = constant_op.constant([1.0, 2 / 3.], dtype=dtypes.float64)
@@ -127,7 +127,7 @@ class GetSamplesTest(test.TestCase):
"""Test the private method 'get_samples'."""
def test_raises_if_both_z_and_n_are_none(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = None
@@ -136,7 +136,7 @@ class GetSamplesTest(test.TestCase):
_get_samples(dist, z, n, seed)
def test_raises_if_both_z_and_n_are_not_none(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(seed=42)
n = 1
@@ -145,7 +145,7 @@ class GetSamplesTest(test.TestCase):
_get_samples(dist, z, n, seed)
def test_returns_n_samples_if_n_provided(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = None
n = 10
@@ -154,7 +154,7 @@ class GetSamplesTest(test.TestCase):
self.assertEqual((10,), z.get_shape())
def test_returns_z_if_z_provided(self):
- with self.test_session():
+ with self.cached_session():
dist = normal_lib.Normal(loc=0., scale=1.)
z = dist.sample(10, seed=42)
n = None
@@ -166,7 +166,7 @@ class GetSamplesTest(test.TestCase):
class ExpectationTest(test.TestCase):
def test_works_correctly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant([-1e6, -100, -10, -1, 1, 10, 100, 1e6])
p = normal_lib.Normal(loc=x, scale=1.)
@@ -213,7 +213,7 @@ class ExpectationTest(test.TestCase):
rtol=0.05, atol=0.)
def test_docstring_example_normal(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_draws = int(1e5)
mu_p = constant_op.constant(0.)
mu_q = constant_op.constant(1.)
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 9afe3df585..18d40fc1df 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.util import deprecation
__all__ = [
'expectation',
@@ -66,7 +67,7 @@ def expectation_importance_sampler(f,
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `sampling_dist_q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -141,7 +142,7 @@ def expectation_importance_sampler_logspace(
shape broadcastable to `q.batch_shape`.
For example, `log_p` works "just like" `q.log_prob`.
sampling_dist_q: The sampling distribution.
- `tf.contrib.distributions.Distribution`.
+ `tfp.distributions.Distribution`.
`float64` `dtype` recommended.
`log_p` and `q` should be supported on the same set.
z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`.
@@ -188,6 +189,12 @@ def _logspace_mean(log_values):
return log_mean_of_values
+@deprecation.deprecated(
+ '2018-10-01',
+ 'The tf.contrib.bayesflow library has moved to '
+ 'TensorFlow Probability (https://github.com/tensorflow/probability). '
+ 'Use `tfp.monte_carlo.expectation` instead.',
+ warn_once=True)
def expectation(f, samples, log_prob=None, use_reparametrization=True,
axis=0, keep_dims=False, name=None):
r"""Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\).
@@ -236,17 +243,17 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
Example Use:
```python
- bf = tf.contrib.bayesflow
- ds = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Monte-Carlo approximation of a reparameterized distribution, e.g., Normal.
num_draws = int(1e5)
- p = ds.Normal(loc=0., scale=1.)
- q = ds.Normal(loc=1., scale=2.)
- exact_kl_normal_normal = ds.kl_divergence(p, q)
+ p = tfd.Normal(loc=0., scale=1.)
+ q = tfd.Normal(loc=1., scale=2.)
+ exact_kl_normal_normal = tfd.kl_divergence(p, q)
# ==> 0.44314718
- approx_kl_normal_normal = bf.expectation(
+ approx_kl_normal_normal = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -260,9 +267,9 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
num_draws = int(1e5)
p = ds.Gamma(concentration=1., rate=1.)
q = ds.Gamma(concentration=2., rate=3.)
- exact_kl_gamma_gamma = ds.kl_divergence(p, q)
+ exact_kl_gamma_gamma = tfd.kl_divergence(p, q)
# ==> 0.37999129
- approx_kl_gamma_gamma = bf.expectation(
+ approx_kl_gamma_gamma = tfp.monte_carlo.expectation(
f=lambda x: p.log_prob(x) - q.log_prob(x),
samples=p.sample(num_draws, seed=42),
log_prob=p.log_prob,
@@ -278,7 +285,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True,
KL-divergence, the following is preferred:
```python
- approx_kl_p_q = bf.monte_carlo_csiszar_f_divergence(
+ approx_kl_p_q = tfp.vi.monte_carlo_csiszar_f_divergence(
f=bf.kl_reverse,
p_log_prob=q.log_prob,
q=p,
diff --git a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
index e36f7f32c6..316da9ebe1 100644
--- a/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
+++ b/tensorflow/contrib/bigtable/python/kernel_tests/bigtable_ops_test.py
@@ -61,7 +61,7 @@ class BigtableOpsTest(test.TestCase):
n = itr.get_next()
expected = list(self.COMMON_ROW_KEYS)
expected.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -84,7 +84,7 @@ class BigtableOpsTest(test.TestCase):
expected_keys.reverse()
expected_values = list(self.COMMON_VALUES)
expected_values.reverse()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i in range(3):
@@ -125,7 +125,7 @@ class BigtableOpsTest(test.TestCase):
expected_keys = list(self.COMMON_ROW_KEYS)
expected_values = list(self.COMMON_VALUES)
expected_tuples = zip(expected_keys, expected_values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elem in enumerate(expected_tuples):
@@ -144,7 +144,7 @@ class BigtableOpsTest(test.TestCase):
itr = ds.make_initializable_iterator()
n = itr.get_next()
expected_key = self.COMMON_ROW_KEYS[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
output = sess.run(n)
@@ -163,7 +163,7 @@ class BigtableOpsTest(test.TestCase):
def runSampleKeyPairsTest(self, ds, expected_key_pairs):
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
for i, elems in enumerate(expected_key_pairs):
@@ -219,7 +219,7 @@ class BigtableOpsTest(test.TestCase):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="r1", end="")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -227,7 +227,7 @@ class BigtableOpsTest(test.TestCase):
ds = bigtable_api._BigtableSampleKeyPairsDataset(
self._table, prefix="r", start="", end="r3")
itr = ds.make_initializable_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(itr.initializer)
@@ -235,7 +235,7 @@ class BigtableOpsTest(test.TestCase):
ds = self._table.parallel_scan_prefix(prefix="r", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
@@ -253,7 +253,7 @@ class BigtableOpsTest(test.TestCase):
ds = self._table.parallel_scan_range(start="r1", end="r4", cf1="c1")
itr = ds.make_initializable_iterator()
n = itr.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._writeCommonValues(sess)
sess.run(itr.initializer)
expected_values = list(zip(self.COMMON_ROW_KEYS, self.COMMON_VALUES))
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index 3e1b622867..cf56822ff4 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -575,7 +575,7 @@ def _normalize_columns(columns, provided_kwargs):
return normalized
-class _BigtableKeyDataset(dataset_ops.Dataset):
+class _BigtableKeyDataset(dataset_ops.DatasetSource):
"""_BigtableKeyDataset is an abstract class representing the keys of a table.
"""
@@ -645,7 +645,7 @@ class _BigtableSampleKeysDataset(_BigtableKeyDataset):
table=self._table._resource) # pylint: disable=protected-access
-class _BigtableLookupDataset(dataset_ops.Dataset):
+class _BigtableLookupDataset(dataset_ops.DatasetSource):
"""_BigtableLookupDataset represents a dataset that retrieves values for keys.
"""
@@ -678,7 +678,7 @@ class _BigtableLookupDataset(dataset_ops.Dataset):
columns=self._columns)
-class _BigtableScanDataset(dataset_ops.Dataset):
+class _BigtableScanDataset(dataset_ops.DatasetSource):
"""_BigtableScanDataset represents a dataset that retrieves keys and values.
"""
@@ -715,7 +715,7 @@ class _BigtableScanDataset(dataset_ops.Dataset):
probability=self._probability)
-class _BigtableSampleKeyPairsDataset(dataset_ops.Dataset):
+class _BigtableSampleKeyPairsDataset(dataset_ops.DatasetSource):
"""_BigtableSampleKeyPairsDataset returns key pairs from a Bigtable table.
"""
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index 5fcb19a47a..14b6fc4ac2 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -173,6 +173,7 @@ py_library(
py_test(
name = "dnn_tree_combined_estimator_test",
size = "medium",
+ timeout = "long",
srcs = ["dnn_tree_combined_estimator_test.py"],
srcs_version = "PY2AND3",
tags = [
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index 78232fa0a6..48f12a64f9 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -51,6 +51,7 @@ def make_custom_export_strategy(name,
feature_columns: A list of feature columns.
export_input_fn: A function that takes no arguments and returns an
`InputFnOps`.
+ use_core_columns: A boolean, whether core feature columns were used.
Returns:
An `ExportStrategy`.
@@ -196,7 +197,7 @@ def convert_to_universal_format(dtec, sorted_feature_names,
matching_id.int64_value = split.feature_id
node.custom_left_child_test.Pack(categorical_test)
else:
- raise ValueError("Unexpected node type %s", node_type)
+ raise ValueError("Unexpected node type %s" % node_type)
node.left_child_id.value = split.left_id
node.right_child_id.value = split.right_id
return model_and_features
@@ -236,7 +237,7 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats,
assert tree_node.node_metadata.gain == 0
continue
else:
- raise ValueError("Unexpected split type %s", node_type)
+ raise ValueError("Unexpected split type %s" % node_type)
# Apply shrinkage factor. It is important since it is not always uniform
# across different trees.
sums[split_column] += (
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 1375fddf2b..606da663dc 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -296,8 +296,9 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
// Create a reference to the underlying resource using the handle.
OP_REQUIRES_OK(context,
@@ -709,8 +710,9 @@ class QuantileAccumulatorGetBucketsOp : public OpKernel {
&buckets_list, stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
QuantileStreamResource* streams_resource;
OP_REQUIRES_OK(context,
LookupResource(context, handle, &streams_resource));
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 3b28ed77f3..8edb5d6c64 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -579,13 +579,6 @@ class BuildSparseInequalitySplitsOp : public OpKernel {
const int end_index =
partition_boundaries[non_empty_partitions[root_idx]][j + 1]
.start_index;
- CHECK(bucket_ids_and_dimensions(start_index, 1) ==
- bucket_ids_and_dimensions(end_index - 1, 1))
- << "For bucket " << bucket_ids_and_dimensions(start_index, 0)
- << " the dimension was "
- << bucket_ids_and_dimensions(start_index, 1) << " and for "
- << bucket_ids_and_dimensions(end_index - 1, 0) << " "
- << bucket_ids_and_dimensions(end_index - 1, 1);
if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) {
// 0-dimension case which has a first bucket for catch all feature.
CHECK(bucket_ids_and_dimensions(start_index, 1) == 0)
@@ -746,21 +739,22 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
- std::vector<int32> non_empty_partitions;
- for (int i = 0; i < partition_ids.size() - 1; ++i) {
+ partition_boundaries.push_back(0);
+ for (int i = 1; i < partition_ids.size(); ++i) {
// Make sure the input is sorted by partition_ids;
- CHECK_LE(partition_ids(i), partition_ids(i + 1));
- if (i == 0 || partition_ids(i) != partition_ids(i - 1)) {
+ OP_REQUIRES(context, partition_ids(i - 1) <= partition_ids(i),
+ errors::InvalidArgument("Partition IDs must be sorted."));
+ if (partition_ids(i) != partition_ids(i - 1)) {
partition_boundaries.push_back(i);
- // Some partitions might only have bias feature. We don't want to split
- // those so check that the partition has at least 2 features.
- if (partition_ids(i) == partition_ids(i + 1)) {
- non_empty_partitions.push_back(partition_boundaries.size() - 1);
- }
}
}
- if (partition_ids.size() > 0) {
- partition_boundaries.push_back(partition_ids.size());
+ std::vector<int32> non_empty_partitions;
+ partition_boundaries.push_back(partition_ids.size());
+ for (int i = 0; i < partition_boundaries.size() - 1; ++i) {
+ // We want to ignore partitions with only the bias term.
+ if (partition_boundaries[i + 1] - partition_boundaries[i] >= 2) {
+ non_empty_partitions.push_back(i);
+ }
}
int num_elements = non_empty_partitions.size();
Tensor* output_partition_ids_t = nullptr;
@@ -862,6 +856,15 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
equality_split->set_feature_column(state->feature_column_group_id());
+ CHECK(feature_ids(best_feature_idx, 0) != bias_feature_id)
+ << "Unexpected feature ID selected. "
+ << "Start feature ID: [" << start_index << "] "
+ << feature_ids(start_index, 0) << ", " << feature_ids(start_index, 1)
+ << "\nBest feature ID: [" << best_feature_idx << "] "
+ << feature_ids(best_feature_idx, 0) << ", "
+ << feature_ids(best_feature_idx, 1)
+ << "\nPartition IDS: " << partition_ids(start_index) << " "
+ << partition_ids(best_feature_idx);
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
index 90a0655201..e446c411a8 100644
--- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc
@@ -448,8 +448,9 @@ class StatsAccumulatorScalarAddOp : public OpKernel {
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
StatsAccumulatorScalarResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
@@ -512,8 +513,9 @@ class StatsAccumulatorTensorAddOp : public OpKernel {
stamp_token](int64 start, int64 end) {
for (int resource_handle_idx = start; resource_handle_idx < end;
++resource_handle_idx) {
- ResourceHandle handle = resource_handle_list[resource_handle_idx]
- .flat<ResourceHandle>()(0);
+ const ResourceHandle& handle =
+ resource_handle_list[resource_handle_idx]
+ .flat<ResourceHandle>()(0);
StatsAccumulatorTensorResource* accumulator_resource;
OP_REQUIRES_OK(context, LookupResource(context, handle,
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index 35d727482b..4da25298cb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
-_BIAS_FEATURE_ID = -1
+_BIAS_FEATURE_ID = int(dtypes.int64.min)
class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index 94ea7bc2eb..a2f708081a 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -170,7 +170,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 1 | 1 |
@@ -577,6 +577,92 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(gains), 0)
self.assertEqual(len(splits), 0)
+ def testLastOneEmpty(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 0 | 1,2 |
+ # i1 | (-0.5, 0.07) | 0 | |
+ # i2 | (1.2, 0.2) | 0 | 2 |
+ # i3 | (4.0, 0.13) | 1 | |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [0, 0, 0, 1]
+ indices = [[0, 0], [0, 1], [2, 0]]
+ values = array_ops.constant([1, 2, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([0], partitions)
+
+ # Check the split on partition 0.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight = -0.9848484848484846
+
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain = 0.46043165467625885
+
+ split_info = split_info_pb2.SplitInfo()
+ split_info.ParseFromString(splits[0])
+ left_child = split_info.left_child.vector
+ right_child = split_info.right_child.vector
+ split_node = split_info.split_node.categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+
+ self.assertEqual(2, split_node.feature_id)
+
+ self.assertAllClose(
+ expected_left_gain + expected_right_gain - expected_bias_gain, gains[0],
+ 0.00001)
+
+ self.assertAllClose([expected_left_weight], left_child.value, 0.00001)
+
+ self.assertAllClose([expected_right_weight], right_child.value, 0.00001)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 73e41bc457..9d9941f696 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -86,7 +86,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeatures(self):
"""Tests feature extraction."""
- with self.test_session():
+ with self.cached_session():
features = {}
features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -128,7 +128,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeaturesWithTransformation(self):
"""Tests feature extraction."""
- with self.test_session():
+ with self.cached_session():
features = {}
features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32)
features["sparse_float"] = sparse_tensor.SparseTensor(
@@ -178,7 +178,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testExtractFeaturesFromCoreFeatureColumns(self):
"""Tests feature extraction when using core columns."""
- with self.test_session():
+ with self.cached_session():
features = {}
# Sparse float column does not exist in core, so only dense numeric and
# categorical.
@@ -213,7 +213,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefNoBiasCentering(self):
"""Tests the train function running on chief without bias centering."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -316,7 +316,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
self.assertProtoEquals(expected_tree, output.trees[0])
def testObliviousDecisionTreeAsWeakLearner(self):
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -473,7 +473,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -580,7 +580,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefScalingNumberOfExamples(self):
"""Tests the train function running on chief without bias centering."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -685,7 +685,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefWithBiasCentering(self):
"""Tests the train function running on chief with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -757,7 +757,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnNonChiefNoBiasCentering(self):
"""Tests the train function running on worker without bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -821,7 +821,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnNonChiefWithCentering(self):
"""Tests the train function running on worker with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -885,7 +885,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testPredictFn(self):
"""Tests the predict function."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -939,7 +939,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testPredictFnWithLeafIndexAdvancedLeft(self):
"""Tests the predict function with output leaf ids."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1051,7 +1051,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassFullHessian(self):
"""Tests the GBDT train for multiclass full hessian."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1155,7 +1155,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassDiagonalHessian(self):
"""Tests the GBDT train for multiclass diagonal hessian."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1259,7 +1259,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnMulticlassTreePerClass(self):
"""Tests the GBDT train for multiclass tree per class strategy."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1374,7 +1374,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -1493,7 +1493,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionWithGoodSplits(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
@@ -1610,7 +1610,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testTrainFnChiefFeatureSelectionReachedLimitIncrementAttemptedLayer(self):
"""Tests the train function running on chief with feature selection."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree = tree_ensemble_config.trees.add()
@@ -1720,7 +1720,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelBeforeAndAfterSplit(self):
"""Tests whether resetting works."""
- with self.test_session():
+ with self.cached_session():
# First build a small tree and train it to verify training works.
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
@@ -1854,7 +1854,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelNonChief(self):
"""Tests the reset function on a non-chief worker."""
- with self.test_session():
+ with self.cached_session():
# Create ensemble with one bias node.
ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1930,7 +1930,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
def testResetModelWithCenterBias(self):
"""Tests the reset function running on chief with bias centering."""
- with self.test_session():
+ with self.cached_session():
ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
index ccb8509c03..cc22504c8f 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py
@@ -45,7 +45,7 @@ class LossesTest(test_util.TensorFlowTestCase):
eps = 0.2
- with self.test_session():
+ with self.cached_session():
predictions_tensor = constant_op.constant(
prediction_logits, dtype=dtypes.float32)
loss_for_positives, _ = losses.per_example_exp_loss(
@@ -84,7 +84,7 @@ class LossesTest(test_util.TensorFlowTestCase):
predictions = np.array(
[[0.123], [23.2], [233], [52], [3]], dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
loss_tensor, _ = losses.per_example_squared_loss(labels, weights,
predictions)
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 150d734db6..94b7f4f867 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -37,6 +37,7 @@ Checkpoint management:
Saving and restoring Python state:
@@NumpyState
+@@PythonStateWrapper
"""
from __future__ import absolute_import
@@ -45,6 +46,7 @@ from __future__ import print_function
from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.python_state import NumpyState
+from tensorflow.contrib.checkpoint.python.python_state import PythonStateWrapper
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
diff --git a/tensorflow/contrib/checkpoint/python/python_state.py b/tensorflow/contrib/checkpoint/python/python_state.py
index 9b11035b6d..302d5cfb79 100644
--- a/tensorflow/contrib/checkpoint/python/python_state.py
+++ b/tensorflow/contrib/checkpoint/python/python_state.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import functools
+import six
import numpy
@@ -101,7 +103,7 @@ class NumpyState(base.CheckpointableBase):
# TODO(allenl): Consider supporting lists/tuples, either ad-hoc or by making
# ndarrays checkpointable natively and using standard checkpointable list
# tracking.
- if isinstance(value, numpy.ndarray):
+ if isinstance(value, (numpy.ndarray, numpy.generic)):
try:
existing = super(NumpyState, self).__getattribute__(name)
existing.array = value
@@ -127,7 +129,29 @@ class NumpyState(base.CheckpointableBase):
super(NumpyState, self).__setattr__(name, value)
-class _NumpyWrapper(base.CheckpointableBase):
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateWrapper(base.CheckpointableBase):
+ """Wraps a Python object for storage in an object-based checkpoint."""
+
+ @abc.abstractmethod
+ def _serialize(self):
+ """Callback for `PythonStringStateSaveable` to serialize the object."""
+
+ @abc.abstractmethod
+ def _deserialize(self, string_value):
+ """Callback for `PythonStringStateSaveable` to deserialize the object."""
+
+ def _gather_saveables_for_checkpoint(self):
+ """Specify callbacks for saving and restoring `array`."""
+ return {
+ "py_state": functools.partial(
+ base.PythonStringStateSaveable,
+ state_callback=self._serialize,
+ restore_callback=self._deserialize)
+ }
+
+
+class _NumpyWrapper(PythonStateWrapper):
"""Wraps a NumPy array for storage in an object-based checkpoint."""
def __init__(self, array):
@@ -139,7 +163,7 @@ class _NumpyWrapper(base.CheckpointableBase):
self.array = array
def _serialize(self):
- """Callback for `PythonStringStateSaveable` to serialize the array."""
+ """Callback to serialize the array."""
string_file = BytesIO()
try:
numpy.save(string_file, self.array, allow_pickle=False)
@@ -149,18 +173,10 @@ class _NumpyWrapper(base.CheckpointableBase):
return serialized
def _deserialize(self, string_value):
- """Callback for `PythonStringStateSaveable` to deserialize the array."""
+ """Callback to deserialize the array."""
string_file = BytesIO(string_value)
try:
self.array = numpy.load(string_file, allow_pickle=False)
finally:
string_file.close()
- def _gather_saveables_for_checkpoint(self):
- """Specify callbacks for saving and restoring `array`."""
- return {
- "array": functools.partial(
- base.PythonStringStateSaveable,
- state_callback=self._serialize,
- restore_callback=self._deserialize)
- }
diff --git a/tensorflow/contrib/checkpoint/python/python_state_test.py b/tensorflow/contrib/checkpoint/python/python_state_test.py
index 0439a4755e..45494351ff 100644
--- a/tensorflow/contrib/checkpoint/python/python_state_test.py
+++ b/tensorflow/contrib/checkpoint/python/python_state_test.py
@@ -40,10 +40,13 @@ class NumpyStateTests(test.TestCase):
save_state.a = numpy.ones([2, 2])
save_state.b = numpy.ones([2, 2])
save_state.b = numpy.zeros([2, 2])
+ save_state.c = numpy.int64(3)
self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
+ self.assertEqual(3, save_state.c)
first_save_path = saver.save(prefix)
save_state.a[1, 1] = 2.
+ save_state.c = numpy.int64(4)
second_save_path = saver.save(prefix)
load_state = python_state.NumpyState()
@@ -51,6 +54,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(first_save_path).initialize_or_restore()
self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(3, load_state.c)
load_state.a[0, 0] = 42.
self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
loader.restore(first_save_path).run_restore_ops()
@@ -58,6 +62,7 @@ class NumpyStateTests(test.TestCase):
loader.restore(second_save_path).run_restore_ops()
self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
+ self.assertEqual(4, load_state.c)
def testNoGraphPollution(self):
graph = ops.Graph()
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 0b79f718d4..77242b34fd 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -1,6 +1,10 @@
TensorFlow CMake build
======================
+CMAKE build is deprecated for TensorFlow. Please use `bazel` to build TF for all
+platforms. For details, see the
+[TensorFlow install guide](https://www.tensorflow.org/install/).
+
This directory contains CMake files for building TensorFlow on Microsoft
Windows. [CMake](https://cmake.org) is a cross-platform tool that can
generate build scripts for multiple build systems, including Microsoft
@@ -13,7 +17,7 @@ Linux.
Current Status
--------------
-CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/install_windows)
+CMake can be used to build TensorFlow on Windows. See the [getting started documentation](https://www.tensorflow.org/install/source_windows)
for instructions on how to install a pre-built TensorFlow package on Windows.
### Current known limitations
diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake
index ad2af01bc0..1a147e9c8e 100644
--- a/tensorflow/contrib/cmake/external/png.cmake
+++ b/tensorflow/contrib/cmake/external/png.cmake
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
include (ExternalProject)
+include (GNUInstallDirs)
set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive)
set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz)
@@ -35,7 +36,7 @@ if(WIN32)
endif()
endif()
else()
- set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a)
+ set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a)
endif()
set(png_HEADERS
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fb871acae9..c0763f4c0e 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -273,9 +273,6 @@ tensorflow/contrib/libsvm
tensorflow/contrib/libsvm/python
tensorflow/contrib/libsvm/python/kernel_tests
tensorflow/contrib/libsvm/python/ops
-tensorflow/contrib/linalg
-tensorflow/contrib/linalg/python
-tensorflow/contrib/linalg/python/ops
tensorflow/contrib/linear_optimizer
tensorflow/contrib/linear_optimizer/kernels
tensorflow/contrib/linear_optimizer/kernels/g3doc
@@ -409,7 +406,6 @@ tensorflow/contrib/summary
tensorflow/contrib/tensorboard
tensorflow/contrib/tensorboard/plugins
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
# TODO(sami): Add cmake implementations.
# tensorflow/contrib/tensorrt/python
# tensorflow/contrib/tensorrt/python/ops
diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt
index cf1ee2ad76..42afbd9105 100644
--- a/tensorflow/contrib/cmake/python_protos.txt
+++ b/tensorflow/contrib/cmake/python_protos.txt
@@ -12,7 +12,6 @@ tensorflow/contrib/mpi_collectives
tensorflow/contrib/session_bundle
tensorflow/contrib/tensor_forest/proto
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tensorboard/plugins/trace
tensorflow/contrib/tpu/proto
tensorflow/contrib/tpu/profiler
tensorflow/contrib/training/python/training
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 2c878c1716..ed31351d9e 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
file(GLOB_RECURSE tf_test_src_py
${tf_test_src_py}
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
- "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
index d5e14e7a64..f5431ca1ff 100644
--- a/tensorflow/contrib/coder/python/ops/coder_ops_test.py
+++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py
@@ -45,7 +45,7 @@ class CoderOpsTest(test.TestCase):
decoded = coder_ops.range_decode(
encoded, array_ops.shape(data), cdf, precision=14)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(*sess.run((data, decoded)))
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index d7583be6d8..f83386b8a4 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -5,7 +5,10 @@ package(default_visibility = [":friends"])
package_group(
name = "friends",
includes = ["//tensorflow/compiler/jit:friends"],
- packages = ["//tensorflow/..."],
+ packages = [
+ "//tensorflow/...",
+ "//third_party/py/tensor2tensor/...",
+ ],
)
load("//tensorflow:tensorflow.bzl", "tf_py_test")
@@ -53,12 +56,16 @@ py_library(
srcs = ["xla.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/compiler/jit:xla_ops_py",
+ "//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:platform",
+ "//tensorflow/python:summary_op_util",
"//tensorflow/python:util",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
],
)
diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py
index 42b3b9f026..3e631b5909 100644
--- a/tensorflow/contrib/compiler/jit_test.py
+++ b/tensorflow/contrib/compiler/jit_test.py
@@ -173,7 +173,7 @@ class JITTest(test.TestCase):
class CompilationEnabledInGradientTest(test.TestCase):
def testCompilationInGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([[3.]])
y_nc = math_ops.matmul(x, x, name="not_compiled")
with jit.experimental_jit_scope():
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 60f5af1662..1e30525159 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -12,20 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
-"""xla provides experimental xla support API."""
+"""xla is an experimental library that provides XLA support APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+import contextlib
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import summary_op_util
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
+from tensorflow.python.util import tf_decorator
_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5
@@ -51,6 +60,30 @@ _UNSUPPORTED_OPS = set([
])
+def compile(computation, inputs=None): # pylint: disable=redefined-builtin
+ """Builds an operator that compiles and runs `computation` with XLA.
+
+ Args:
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
+
+ `computation` may return a list of operations and tensors. Tensors must
+ come before operations in the returned list. The return value of
+ `compile` is a list of tensors corresponding to the tensors from the
+ output of `computation`.
+
+ All `Operation`s returned from `computation` will be executed when
+ evaluating any of the returned output tensors.
+ inputs: A list of input tensors or `None` (equivalent to an empty list).
+
+ Returns:
+ A list of output tensors.
+ """
+ # pylint: disable=protected-access
+ return _compile_internal(computation, inputs)
+
+
class XLACompileContext(control_flow_ops.XLAControlFlowContext):
"""A `ControlFlowContext` for nodes inside an XLA computation cluster.
@@ -206,3 +239,410 @@ class XLACompileContext(control_flow_ops.XLAControlFlowContext):
if self.GetWhileContext():
return self.GetWhileContext().back_prop
return False
+
+
+def _compile_internal(computation, inputs=None):
+ """Builds graph operators that compiles and symbolically executes computation.
+
+ Args:
+ computation: A Python function that builds the computation to compile and
+ execute.
+ inputs: A list of input tensors or `None` (equivalent to `[]`). Its order
+ should match ordering of computation arguments.
+ Returns:
+ A list of output tensors from computation.
+ Raises:
+ ValueError: If any element in computation outputs is neither an operations
+ or a value that can be converted to tensor.
+ TypeError: If `inputs` is not a list or tuple.
+ """
+ if inputs is None:
+ inputs = []
+
+ if not isinstance(inputs, collections.Sequence):
+ raise TypeError('inputs must be a list')
+
+ # Converts inputs to Tensors.
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
+ input_arity = len(inputs)
+
+ arg_error = tpu_function.check_function_argument_count(
+ computation, input_arity, infeed_queue=None)
+ if arg_error is not None:
+ raise TypeError(
+ 'Supplied computation cannot be called with the specified inputs. You '
+ 'specified %d inputs: %s, but the computation needs %s' %
+ (input_arity, str([i.name for i in inputs[0]]), arg_error))
+
+ cluster_name = ops.get_default_graph().unique_name('cluster')
+ pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
+ context = XLACompileContext(name=cluster_name, pivot=pivot)
+ try:
+ context.Enter()
+
+ # Add identity ops so even unused inputs are 'consumed' by the
+ # computation.
+ computation_inputs = [
+ array_ops.identity(x, name='input_{}'.format(i))
+ for i, x in enumerate(inputs)
+ ]
+
+ # Only resource variables work inside an XLA computation, so turn on
+ # resource variables for the computation.
+ vscope = variable_scope.get_variable_scope()
+ saved_use_resource = vscope.use_resource
+ vscope.set_use_resource(True)
+
+ outputs = computation(*computation_inputs)
+
+ # Restore variable scope after computation.
+ vscope.set_use_resource(saved_use_resource)
+
+ # If the computation returns `None`, make it an empty tuple.
+ if outputs is None:
+ outputs = tuple()
+ # If the computation only returned one value, make it a tuple.
+ if not isinstance(outputs, collections.Sequence):
+ outputs = (outputs,)
+
+ # Append `no_op` here so that return value of this function always contains
+ # at least one op that can trigger XlaLaunch node.
+ outputs += (control_flow_ops.no_op(),)
+ try:
+ outputs = [
+ o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
+ for o in outputs
+ ]
+ except Exception as e:
+ raise ValueError(
+ 'XLA computation function return values must all either be Operations'
+ ' or convertible to Tensors. Got error: "%s"' % str(e))
+
+ # Separates the returned Operations and Tensors.
+ output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
+ output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]
+
+ if outputs != output_tensors + output_operations:
+ raise ValueError(
+ 'XLA computation function must return zero or more Tensor values '
+ 'followed by zero or more Operations.')
+ output_arity = len(output_tensors)
+
+ new_output_tensors = []
+ for t in output_tensors:
+ with ops.device(t.device if t.device else ''):
+ new_output_tensors.append(array_ops.identity(t))
+
+ output_tensors = new_output_tensors
+ context.ExitResult(output_tensors)
+ finally:
+ context.report_unsupported_operations()
+ context.Exit()
+
+ outputs = [
+ xla_ops.xla_cluster_output(output_tensors[i], name='output{}'.format(i))
+ for i in xrange(output_arity)
+ ]
+
+ with ops.control_dependencies(output_operations):
+ if output_arity == 0:
+ # When XLA computation returns only operations and no tensors, a NoOp
+ # dependent on the operations in outputs is returned. Otherwise final
+ # outputs would be empty and there is no way to trigger returned
+ # operations.
+ return control_flow_ops.no_op(name='output_0')
+ else:
+ # Wraps the outputs in identity operators that carries control
+ # dependencies.
+ return [
+ array_ops.identity(outputs[i], name='output_%d' % i)
+ for i in xrange(output_arity)
+ ]
+
+
+@contextlib.contextmanager
+def _disable_summary_context():
+ """Enters a context where all summary ops are skipped.
+
+ Summaries are not yet supported in xla.compile(). So we provide this context
+ manager that can skip creating summary ops. This is a temporary workaround due
+ to XLA not supporting summary ops.
+
+ Yields:
+ None.
+ """
+ origional_skip_summary_func = summary_op_util.skip_summary
+ summary_op_util.skip_summary = lambda: True
+
+ try:
+ yield
+ finally:
+ summary_op_util.skip_summary = origional_skip_summary_func
+
+
+class _CapturedObject(object):
+ """A placeholder to capture an object."""
+
+ def __init__(self):
+ self._object = None
+
+ def capture(self, o):
+ if self._object:
+ raise RuntimeError(
+ 'InternalError: _CapturedObject can capture only once. Please file '
+ 'bug.')
+
+ self._object = o
+
+ def get(self):
+ return self._object
+
+
+def _get_scaffold(captured_scaffold_fn):
+ """Retrieves the Scaffold from `captured_scaffold_fn`."""
+ scaffold_fn = captured_scaffold_fn.get()
+
+ if not scaffold_fn:
+ return None
+
+ scaffold = scaffold_fn()
+ if scaffold is None:
+ raise ValueError(
+ 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
+
+ return scaffold
+
+
+class _ModelFnWrapper(object):
+ """_ModelFnWrapper supports executing model_fn with XLA."""
+
+ def __init__(self, function):
+ self._model_fn = function
+
+ def __call__(self, features, labels, mode, params):
+
+ # TPUEstimator compiles model_fn when use_tpu=True. To avoid double
+ # compilation, we use this params['use_tpu'] as a hint. When it is set to
+ # True, model_fn is called without compilation.
+ # Note that this condition isn't accurate for the case of exporting a model.
+ # In that case we should ideally not compile so that user can see detailed
+ # graph. However, we don't have enough information to tell whether model_fn
+ # is being called for export mode or not.
+ # TODO(ycao): Make this condition more accurate when implementing PREDICT
+ # mode.
+ if params.get('use_tpu'):
+ return self._call_model_fn(features, labels, mode, params)
+
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ train_step, captured_scaffold_fn = self._make_train_step(
+ features, labels, params)
+ with _disable_summary_context():
+ (loss,) = compile(train_step)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ train_op=array_ops.identity(loss),
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ eval_step, captured_eval_metric_fn, captured_scaffold_fn = (
+ self._make_eval_step(features, labels, params))
+ with _disable_summary_context():
+ outputs = compile(eval_step)
+ loss = outputs[0]
+
+ # Calculate eval_metric_ops if eval_metric_fn is set and captured.
+ eval_metric_fn = captured_eval_metric_fn.get()
+ if eval_metric_fn:
+ eval_metric_fn_tensors = outputs[1:]
+ eval_metric_ops = eval_metric_fn(*eval_metric_fn_tensors)
+ else:
+ eval_metric_ops = None
+
+ return model_fn_lib.EstimatorSpec(
+ mode=mode,
+ loss=loss,
+ eval_metric_ops=eval_metric_ops,
+ scaffold=_get_scaffold(captured_scaffold_fn))
+ else:
+ raise NotImplementedError('%s is not implemented, only TRAIN and EVAL are'
+ ' supported' % mode)
+
+ def _make_train_step(self, features, labels, params):
+ """Creates a single step of training for xla.compile()."""
+ captured_scaffold_fn = _CapturedObject()
+
+ def train_step():
+ """A single step of training."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.TRAIN, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ # train_step will be run by xla.compile(). xla.compile() only supports
+ # tensor output while train_op can be either an operation or a tensor.
+ # Even though xla.compile() automatically adds operation-typed train_op as
+ # control dependency of other tensor outputs, it doesn't do so for
+ # tensor-typed train_op. Thus, we need to set it explicitly here.
+ with ops.control_dependencies([estimator_spec.train_op]):
+ return array_ops.identity(estimator_spec.loss)
+
+ return train_step, captured_scaffold_fn
+
+ def _make_eval_step(self, features, labels, params):
+ """Creates a single step of evaluation for xla.compile()."""
+ captured_eval_metric_fn = _CapturedObject()
+ captured_scaffold_fn = _CapturedObject()
+
+ def eval_step():
+ """A single step of evaluation."""
+ estimator_spec = self._call_model_fn(features, labels,
+ model_fn_lib.ModeKeys.EVAL, params)
+
+ try:
+ captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
+ except AttributeError:
+ captured_scaffold_fn.capture(None)
+
+ eval_metric_fn = None
+ eval_metric_fn_tensors = []
+ try:
+ if estimator_spec.eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
+ except AttributeError:
+ pass
+
+ # If a dictionary is provided, we need to convert it into a list sorted
+ # according to order of eval_metric_fn positional arguments.
+ if isinstance(eval_metric_fn_tensors, dict):
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+ eval_metric_fn_tensors = [
+ eval_metric_fn_tensors[i] for i in eval_metric_fn_args
+ ]
+
+ captured_eval_metric_fn.capture(eval_metric_fn)
+
+ return tuple([estimator_spec.loss] + eval_metric_fn_tensors)
+
+ return eval_step, captured_eval_metric_fn, captured_scaffold_fn
+
+ def _call_model_fn(self, features, labels, mode, params):
+ """Calls the model_fn with required parameters."""
+ model_fn_args = function_utils.fn_args(self._model_fn)
+ kwargs = {}
+
+ if 'labels' in model_fn_args:
+ kwargs['labels'] = labels
+ elif labels is not None:
+ raise ValueError(
+ 'model_fn does not take labels, but input_fn returns labels.')
+ if 'mode' in model_fn_args:
+ kwargs['mode'] = mode
+
+ if 'params' in model_fn_args:
+ kwargs['params'] = params
+
+ return self._verify_estimator_spec(
+ self._model_fn(features=features, **kwargs))
+
+ def _verify_estimator_spec(self, estimator_spec):
+ """Verifies estimator spec contains correct data."""
+ # TODO(ycao): Implement estimator spec verification for other modes.
+
+ try:
+ if estimator_spec.scaffold:
+ logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
+ '. Please use TPUEstimatorSpec.scaffold_fn instead.')
+ except AttributeError:
+ pass
+
+ try:
+ if estimator_spec.eval_metric_ops:
+ raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
+ 'XLA compilation. Please use '
+ 'TPUEstimatorSpec.eval_metrics instead.')
+ except AttributeError:
+ pass
+
+ if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
+ # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
+ # check that eval_metrics contains eval_metric_fn and
+ # eval_metric_fn_tensors with matching arguments.
+ try:
+ eval_metrics = estimator_spec.eval_metrics
+ except AttributeError:
+ eval_metrics = None
+
+ if eval_metrics:
+ (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
+ eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
+
+ if isinstance(eval_metric_fn_tensors, dict):
+ missing_tensors = [
+ i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
+ ]
+ additional_tensors = [
+ i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
+ ]
+
+ if missing_tensors:
+ raise ValueError('Arguments %s are needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics) but '
+ 'they are not provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ '.' % missing_tensors)
+
+ if additional_tensors:
+ raise ValueError('Arguments %s are provided by evaluation tensors '
+ '(second element of TPUEstimatorSpec.eval_metrics)'
+ ' but they are not needed by metric_fn (first '
+ 'element of TPUEstimatorSpec.eval_metrics).' %
+ additional_tensors)
+
+ return estimator_spec
+
+
+def estimator_model_fn(target_model_fn=None):
+ """estimator_model_fn decorates a model_fn to be compiled for execution.
+
+ Currently only it only works with `TPUEstimator`. If you need to use it with
+ base `Estimator`, please add `tf.enable_resource_variables()` at beginning of
+ your program.
+
+ Example 1, decorating model_fn:
+ ```
+ @xla.estimator_model_fn()
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+
+ est = Estimator(model_fn=model_fn, ...)
+ est.train(...)
+
+ ```
+
+ Example 2, decorator as function:
+ ```
+ def model_fn(features, labels, mode, params):
+ ...
+ return EstimatorSpec(...)
+
+ est = Estimator(model_fn=xla.estimator_model_fn(model_fn), ...)
+ est.train(...)
+ ```
+
+ Args:
+ target_model_fn: model_fn to be decorated. This is only needed when
+ decorator is used in function call form (example 2).
+
+ Returns:
+ Decorated target_model_fn.
+ """
+
+ def decorated(function):
+ return tf_decorator.make_decorator(function, _ModelFnWrapper(function))
+
+ return decorated(target_model_fn) if target_model_fn else decorated
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5a667485be..c59d3682d4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -413,6 +413,31 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase):
self._testOneLSTMParamsSize(num_layers, num_units, input_size,
direction)
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testLSTMParamsSizeShape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ constant_op.constant([4]), 200, 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, constant_op.constant([200]), 200,
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ model = _CreateModel(
+ cudnn_rnn_ops.CUDNN_LSTM,
+ 4, 200, constant_op.constant([200]),
+ direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
+ params_size = model.params_size()
+
class CudnnRNNTestInference(TensorFlowTestCase):
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index fda1b9f1b3..57793a8ff5 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -460,7 +460,7 @@ class CudnnRNNTestBasic(test_util.TensorFlowTestCase):
grad, = gradients.gradients(
math_ops.reduce_sum(accumulation), (original_input,))
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
accumulation_eval, grad_eval = sess.run((accumulation, grad))
self.assertAllEqual([28, 100, 100], accumulation_eval.shape)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index baec238c62..3cb51279c3 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -44,6 +44,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@group_by_reducer
@@group_by_window
@@ignore_errors
+@@latency_stats
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
@@ -57,11 +58,15 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@reduce_dataset
@@sample_from_datasets
@@scan
+@@set_stats_aggregator
@@shuffle_and_repeat
@@sliding_window_batch
@@sloppy_interleave
+@@StatsAggregator
@@unbatch
@@unique
+
+@@AUTOTUNE
"""
from __future__ import absolute_import
@@ -91,6 +96,10 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
+
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -105,6 +114,9 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
+from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
+from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
@@ -113,6 +125,3 @@ from tensorflow.python.data.ops.optional_ops import Optional
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index 74107d5242..21ec50fb6b 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -49,6 +49,9 @@ class CSVDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx,
ctx->input_list("record_defaults", &record_defaults_list));
for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 078de717e0..96f1dd0059 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -476,645 +476,6 @@ class IteratorGetDeviceOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
IteratorGetDeviceOp);
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
-string SanitizeThreadSuffix(string suffix) {
- string clean;
- for (int i = 0; i < suffix.size(); ++i) {
- const char ch = suffix[i];
- if ((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') ||
- (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
- clean += ch;
- } else {
- clean += '_';
- }
- }
- return clean;
-}
-
-struct HostBufferElement {
- Status status;
- bool end_of_sequence;
- std::vector<Tensor> value;
-};
-
-using MultiDeviceIteratorCallback =
- std::function<void(const HostBufferElement&)>;
-
-class MultiDeviceIterator : public ResourceBase {
- public:
- MultiDeviceIterator(const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes,
- const std::vector<string>& devices,
- std::unique_ptr<FunctionLibraryDefinition> flib_def,
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- FunctionLibraryRuntime* lib)
- : output_types_(output_types),
- output_shapes_(output_shapes),
- devices_(devices),
- flib_def_(std::move(flib_def)),
- pflr_(std::move(pflr)),
- lib_(lib) {
- CHECK_NOTNULL(lib_);
- }
-
- string DebugString() override {
- return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
- " devices");
- }
-
- Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
- int64* incarnation_id) {
- if (iterator) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_types_, iterator->output_dtypes()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
- }
-
- mutex_lock l(mu_);
- if (multi_device_buffer_) {
- multi_device_buffer_->Reset();
- }
-
- ++incarnation_id_;
- *incarnation_id = incarnation_id_;
-
- multi_device_buffer_.reset(
- new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
- std::move(iterator)));
- return Status::OK();
- }
-
- void GetNextFromShard(IteratorContext* ctx, int shard_num,
- int64 incarnation_id,
- MultiDeviceIteratorCallback callback) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
- tf_shared_lock l(mu_);
- multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
- std::move(callback));
- }
-
- const DataTypeVector& output_types() const { return output_types_; }
-
- const std::vector<PartialTensorShape>& output_shapes() const {
- return output_shapes_;
- }
-
- std::shared_ptr<const FunctionLibraryDefinition> function_library() {
- tf_shared_lock l(mu_);
- return lib_def_;
- }
-
- FunctionLibraryRuntime* const lib() {
- tf_shared_lock l(mu_);
- return lib_;
- }
-
- private:
- // A private class that uses a background thread to keep a per device buffer
- // full.
- class MultiDeviceBuffer {
- public:
- MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
- std::unique_ptr<IteratorBase> host_iterator)
- : buffer_(size),
- size_(size),
- max_buffer_size_(max_buffer_size),
- incarnation_id_(incarnation_id),
- host_iterator_(std::move(host_iterator)) {}
-
- ~MultiDeviceBuffer() { Reset(); }
-
- void Reset() LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
-
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
- }
- }
- RunPendingCallbacks();
- }
-
- void GetNextFromShard(IteratorContext* ctx, int shard_num,
- int64 incarnation_id,
- MultiDeviceIteratorCallback callback) {
- HostBufferElement elem;
- if (incarnation_id_ != incarnation_id) {
- elem.status = errors::InvalidArgument("Invalid incarnation id");
- callback(elem);
- return;
- }
-
- bool produced_output = false;
- {
- mutex_lock l(mu_);
- if (cancelled_) {
- elem.status = errors::Cancelled("Cancelled Multidevice iterator");
- callback(elem);
- return;
- }
-
- EnsureBackgroundThreadStarted(ctx);
-
- if (!buffer_[shard_num].data.empty()) {
- produced_output = true;
- std::swap(elem, buffer_[shard_num].data.front());
- buffer_[shard_num].data.pop_front();
- // Wake up background thread if it is blocked on this element.
- if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
- buffer_[shard_num].cond_var.notify_all();
- }
- } else {
- if (background_thread_finished_) {
- produced_output = true;
- elem.end_of_sequence = true;
- } else {
- buffer_[shard_num].callbacks.push_back(std::move(callback));
- callback = nullptr;
- }
- }
- }
-
- if (produced_output) {
- callback(elem);
- }
- }
-
- private:
- void EnsureBackgroundThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!background_thread_) {
- background_thread_.reset(ctx->env()->StartThread(
- {}, "multi_device_iterator_background_thread",
- std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
- this, new IteratorContext(*ctx))));
- }
- }
-
- void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
- // Run all remaining callbacks.
- std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
- std::vector<HostBufferElement> cancellation_elements;
- {
- mutex_lock l(mu_);
-
- for (int i = 0; i < size_; ++i) {
- while (!buffer_[i].callbacks.empty()) {
- if (buffer_[i].data.empty()) {
- HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
- cancellation_elements.push_back(std::move(elem));
- } else {
- cancellation_elements.push_back(
- std::move(buffer_[i].data.front()));
- buffer_[i].data.pop_front();
- }
- cancellation_callbacks.push_back(
- std::move(buffer_[i].callbacks.front()));
- buffer_[i].callbacks.pop_front();
- }
- }
- }
- for (int i = 0; i < cancellation_callbacks.size(); ++i) {
- cancellation_callbacks[i](cancellation_elements[i]);
- }
- }
-
- void BackgroundThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
- int shard_to_fetch = 0;
- while (true) {
- HostBufferElement elem;
- MultiDeviceIteratorCallback callback = nullptr;
- bool end_of_iterator = false;
-
- {
- mutex_lock l(mu_);
- while (!cancelled_ &&
- buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
- buffer_[shard_to_fetch].cond_var.wait(l);
- }
-
- if (cancelled_) {
- background_thread_finished_ = true;
- shutdown_cond_var_.notify_all();
- return;
- }
- }
-
- elem.status =
- host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
-
- if (elem.status.ok() && elem.end_of_sequence) {
- end_of_iterator = true;
- }
-
- {
- mutex_lock l(mu_);
- // Try to find a callback, else just push stuff into buffer.
- if (!buffer_[shard_to_fetch].callbacks.empty()) {
- callback = buffer_[shard_to_fetch].callbacks.front();
- buffer_[shard_to_fetch].callbacks.pop_front();
- } else {
- buffer_[shard_to_fetch].data.push_back(std::move(elem));
- elem = HostBufferElement();
- }
- }
-
- if (callback) {
- (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
- }
-
- // Finish off the thread if we reach the end of the iterator. Runs
- // pending callbacks.
- if (end_of_iterator) {
- {
- mutex_lock l(mu_);
- background_thread_finished_ = true;
- shutdown_cond_var_.notify_all();
- }
- RunPendingCallbacks();
- return;
- }
- shard_to_fetch = (shard_to_fetch + 1) % size_;
- }
- }
-
- struct HostBuffer {
- condition_variable cond_var;
- std::deque<HostBufferElement> data;
- std::deque<MultiDeviceIteratorCallback> callbacks;
- };
-
- mutex mu_;
- std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
- bool background_thread_finished_ GUARDED_BY(mu_) = false;
- bool cancelled_ GUARDED_BY(mu_) = false;
- condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
-
- std::vector<HostBuffer> buffer_;
-
- const size_t size_;
- const int64 max_buffer_size_;
- const int64 incarnation_id_;
- const std::unique_ptr<IteratorBase> host_iterator_;
- };
-
- mutex mu_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
- const std::vector<string> devices_;
- const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
- std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
-
- int64 incarnation_id_ GUARDED_BY(mu_) = 0;
- std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
-};
-
-// Just creates a MultiDeviceIterator and returns it.
-class MultiDeviceIteratorHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
- }
-
- // The resource is deleted from the resource manager only when it is private
- // to kernel.
- ~MultiDeviceIteratorHandleOp() override {
- if (resource_ != nullptr) {
- resource_->Unref();
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->template Delete<MultiDeviceIterator>(cinfo_.container(),
- cinfo_.name())
- .ok()) {
- // Do nothing; the resource can have been deleted by session resets.
- }
- }
- }
- }
-
- void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (resource_ == nullptr) {
- FunctionLibraryRuntime* lib;
- std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
- OP_REQUIRES_OK(context, context->function_library()->Clone(
- &flib_def, &pflr, &lib));
- ResourceMgr* mgr = context->resource_manager();
- OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
-
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(
- context,
- mgr->LookupOrCreate<MultiDeviceIterator>(
- cinfo_.container(), cinfo_.name(), &resource,
- [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- *ret = new MultiDeviceIterator(
- output_types_, output_shapes_, devices_,
- std::move(flib_def), std::move(pflr), lib);
- return Status::OK();
- }));
-
- Status s = VerifyResource(resource);
- if (TF_PREDICT_FALSE(!s.ok())) {
- resource->Unref();
- context->SetStatus(s);
- return;
- }
-
- resource_ = resource;
- }
- }
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<MultiDeviceIterator>()));
- }
-
- private:
- // During the first Compute(), resource is either created or looked up using
- // shared_name. In the latter case, the resource found should be verified if
- // it is compatible with this op's configuration. The verification may fail in
- // cases such as two graphs asking queues of the same shared name to have
- // inconsistent capacities.
- Status VerifyResource(MultiDeviceIterator* resource) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_types_, resource->output_types()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
- return Status::OK();
- }
-
- mutex mu_;
- ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
- MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
- const int graph_def_version_;
- string name_;
- string container_;
- std::vector<string> devices_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
- MultiDeviceIteratorHandleOp);
-
-// Calls init on the MultiDeviceIterator.
-class MultiDeviceIteratorInitOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor* tensor_max_buffer_size;
- OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
- int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
-
- DatasetBase* dataset;
- OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
- core::ScopedUnref unref(resource);
-
- std::unique_ptr<IteratorBase> iterator;
- IteratorContext iter_ctx(ctx);
- iter_ctx.set_lib(resource->lib());
- OP_REQUIRES_OK(
- ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
- int64 incarnation_id;
- OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
- &incarnation_id));
- Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
- tensor_incarnation_id.scalar<int64>()() = incarnation_id;
- OP_REQUIRES_OK(ctx,
- ctx->set_output("incarnation_id", tensor_incarnation_id));
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
- MultiDeviceIteratorInitOp);
-
-// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
-// TODO(rohanj): Implement using BackgroundWorker that Derek built?
-class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
- public:
- explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx),
- thread_pool_(new thread::ThreadPool(
- ctx->env(), ThreadOptions(),
- strings::StrCat("multi_device_iterator_get_next_thread_",
- SanitizeThreadSuffix(name())),
- 1 /* num_threads */, false /* low_latency_hint */)) {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- const Tensor* tensor_shard_num;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
- int32 shard_num = tensor_shard_num->scalar<int32>()();
-
- const Tensor* tensor_incarnation_id;
- OP_REQUIRES_OK_ASYNC(
- ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
- int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
-
- MultiDeviceIterator* iterator;
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
- thread_pool_->Schedule(std::bind(
- [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.function_library = iterator->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- IteratorContext iter_ctx(std::move(params));
-
- MultiDeviceIteratorCallback callback = std::bind(
- [ctx](const HostBufferElement& elem, DoneCallback done) {
- // iterator->Unref();
- Status s = elem.status;
- if (!s.ok()) {
- ctx->SetStatus(s);
- } else if (elem.end_of_sequence) {
- ctx->SetStatus(errors::OutOfRange("End of sequence"));
- } else {
- for (int i = 0; i < elem.value.size(); ++i) {
- ctx->set_output(i, elem.value[i]);
- }
- }
- done();
- },
- std::placeholders::_1, std::move(done));
-
- iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
- callback);
- iterator->Unref();
- },
- std::move(done)));
- }
-
- private:
- std::unique_ptr<thread::ThreadPool> thread_pool_;
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
- MultiDeviceIteratorGetNextFromShardOp);
-
-class MultiDeviceIteratorToStringHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& resource_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
- errors::InvalidArgument("resource_handle must be a scalar"));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an MultiDeviceIterator.
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
- resource->Unref();
-
- Tensor* string_handle_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &string_handle_t));
- string_handle_t->scalar<string>()() =
- resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
- }
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
- MultiDeviceIteratorToStringHandleOp);
-
-class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
- public:
- explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES(
- ctx,
- output_types_.empty() || output_shapes_.empty() ||
- output_types_.size() == output_shapes_.size(),
- errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
- "are set, they must have the same length."));
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& string_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
- errors::InvalidArgument("string_handle must be a scalar"));
-
- ResourceHandle resource_handle;
- OP_REQUIRES(
- ctx,
- resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
- errors::InvalidArgument(
- "Could not parse string_handle as a valid ResourceHandle"));
-
- OP_REQUIRES(
- ctx, resource_handle.device() == ctx->device()->attributes().name(),
- errors::InvalidArgument("Attempted create an iterator on device \"",
- ctx->device()->attributes().name(),
- "\" from handle defined on device \"",
- resource_handle.device(), "\""));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an MultiDeviceIterator.
- MultiDeviceIterator* resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
- core::ScopedUnref unref_iterator(resource);
- if (!output_types_.empty()) {
- OP_REQUIRES_OK(ctx,
- VerifyTypesMatch(output_types_, resource->output_types()));
- }
- if (!output_shapes_.empty()) {
- OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
- resource->output_shapes()));
- }
-
- Tensor* resource_handle_t;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
- resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
- }
-
- private:
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
- MultiDeviceIteratorFromStringHandleOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index ae104d55bd..d1a771f005 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -65,7 +65,13 @@ REGISTER_OP("CSVDataset")
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
// `record_defaults` must be lists of scalars
for (size_t i = 8; i < c->num_inputs(); ++i) {
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused));
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
}
return shape_inference::ScalarShape(c);
});
@@ -145,82 +151,6 @@ Resets the FunctionBufferingResource.
function_buffer_resource: The FunctionBufferingResource handle.
)doc");
-REGISTER_OP("MultiDeviceIterator")
- .Output("handle: resource")
- .Attr("devices: list(string) >= 1")
- .Attr("shared_name: string")
- .Attr("container: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Creates a MultiDeviceIterator resource.
-
-handle: Handle to the resource created.
-devices: A list of devices the iterator works across.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorInit")
- .Input("dataset: variant")
- .Input("multi_device_iterator: resource")
- .Input("max_buffer_size: int64")
- .Output("incarnation_id: int64")
- .Doc(R"doc(
-Initializes the multi device iterator with the given dataset.
-max_buffer_size: The maximum size of the host side per device buffer to keep.
-incarnation_id: An int64 indicating which incarnation of the MultiDeviceIterator
- is running.
-dataset: Dataset to be iterated upon.
-multi_device_iterator: A MultiDeviceIteratorResource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
- .Input("multi_device_iterator: resource")
- .Input("shard_num: int32")
- .Input("incarnation_id: int64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Doc(R"doc(
-Gets next element for the provided shard number.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-shard_num: Integer representing which shard to fetch data for.
-incarnation_id: Which incarnation of the MultiDeviceIterator is running.
-components: Result of the get_next on the dataset.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorToStringHandle")
- .Input("multi_device_iterator: resource")
- .Output("string_handle: string")
- .Doc(R"doc(
-Produces a string handle for the given MultiDeviceIterator.
-
-multi_device_iterator: A MultiDeviceIterator resource.
-string_handle: A string representing the resource.
-)doc");
-
-REGISTER_OP("MultiDeviceIteratorFromStringHandle")
- .Input("string_handle: string")
- .Output("multi_device_iterator: resource")
- .Attr("output_types: list(type) >= 0 = []")
- .Attr("output_shapes: list(shape) >= 0 = []")
- .Doc(R"doc(
-Generates a MultiDeviceIterator resource from its provided string handle.
-
-string_handle: String representing the resource.
-multi_device_iterator: A MultiDeviceIterator resource.
-output_types: The type list for the return values.
-output_shapes: The list of shapes being produced.
-)doc");
-
REGISTER_OP("ThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 6f0111a2bd..ce52c990ce 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -72,12 +72,13 @@ py_test(
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
],
)
@@ -189,7 +190,6 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
],
)
@@ -276,6 +276,7 @@ py_test(
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
@@ -324,12 +325,7 @@ cuda_py_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
- tags = [
- "manual",
- "no_oss",
- "no_windows_gpu",
- "notap",
- ],
+ tags = ["no_windows_gpu"],
)
py_test(
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 8e368bf2bc..e2508de9e9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -742,7 +742,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -813,7 +813,7 @@ class RestructuredDatasetTest(test.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -837,7 +837,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
@@ -879,7 +879,7 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(5):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 293be2bd06..48971f2ccc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -531,6 +531,45 @@ class BucketTest(test.TestCase):
self.assertEqual(batches, 15)
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
class BucketBySequenceLength(test.TestCase):
def testBucket(self):
@@ -539,39 +578,58 @@ class BucketBySequenceLength(test.TestCase):
batch_sizes = [10, 8, 4, 2]
lengths = [8, 13, 25, 35]
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
def testPadToBoundary(self):
@@ -657,28 +715,108 @@ class BucketBySequenceLength(test.TestCase):
def testTupleElements(self):
- def elements_gen():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (x, y)
-
- def element_length_fn(x, y):
- del y
- return array_ops.shape(x)[0]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator=elements_gen,
- output_shapes=(tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([])),
- output_types=(dtypes.int32, dtypes.int32))
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ Test runs on following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8]))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index 63bffd023f..f8e74e4583 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -31,38 +31,49 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class CsvDatasetOpTest(test.TestCase):
- def _assert_datasets_equal(self, g, ds1, ds2):
+ def _get_next(self, dataset):
+ # Returns a no argument function whose result is fed to self.evaluate to
+ # yield the next element
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ get_next = it.get_next()
+ return lambda: get_next
+
+ def _assert_datasets_equal(self, ds1, ds2):
assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
'%s') % (ds1.output_shapes,
ds2.output_shapes)
assert ds1.output_types == ds2.output_types
assert ds1.output_classes == ds2.output_classes
- next1 = ds1.make_one_shot_iterator().get_next()
- next2 = ds2.make_one_shot_iterator().get_next()
- with self.session(graph=g) as sess:
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = sess.run(next1)
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- sess.run(next2)
- break
- op2 = sess.run(next2)
- self.assertAllEqual(op1, op2)
+ next1 = self._get_next(ds1)
+ next2 = self._get_next(ds2)
+ # Run through datasets and check that outputs match, or errors match.
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except (errors.OutOfRangeError, ValueError) as e:
+ # If op1 throws an exception, check that op2 throws same exception.
+ with self.assertRaises(type(e)):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+ self.assertAllEqual(op1, op2)
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -95,33 +106,32 @@ class CsvDatasetOpTest(test.TestCase):
def _test_by_comparison(self, inputs, **kwargs):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
- with ops.Graph().as_default() as g:
- dataset_actual, dataset_expected = self._make_test_datasets(
- inputs, **kwargs)
- self._assert_datasets_equal(g, dataset_actual, dataset_expected)
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self._assert_datasets_equal(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
- sess,
dataset,
expected_output=None,
expected_err_re=None):
- nxt = dataset.make_one_shot_iterator().get_next()
if expected_err_re is None:
# Verify that output is expected, without errors
+ nxt = self._get_next(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
for value in expected_output:
- op = sess.run(nxt)
+ op = self.evaluate(nxt())
self.assertAllEqual(op, value)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
+ self.evaluate(nxt())
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
+ nxt = self._get_next(dataset)
while True:
try:
- sess.run(nxt)
+ self.evaluate(nxt())
except errors.OutOfRangeError:
break
@@ -137,11 +147,8 @@ class CsvDatasetOpTest(test.TestCase):
# Convert str type because py3 tf strings are bytestrings
filenames = self._setup_files(inputs, linebreak, compression_type)
kwargs['compression_type'] = compression_type
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(sess, dataset, expected_output,
- expected_err_re)
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ self._verify_output_or_err(dataset, expected_output, expected_err_re)
def testCsvDataset_requiredFields(self):
record_defaults = [[]] * 4
@@ -191,21 +198,17 @@ class CsvDatasetOpTest(test.TestCase):
record_defaults = [['']] * 3
inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
record_defaults = [['']] * 3
inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
filenames = self._setup_files(inputs)
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']])
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
record_defaults = [['']] * 3
@@ -351,10 +354,9 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- with ops.Graph().as_default() as g:
- self._assert_datasets_equal(g,
- ds_actual.repeat(5).prefetch(1),
- ds_expected.repeat(5).prefetch(1))
+ self._assert_datasets_equal(
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
def testCsvDataset_withTypeDefaults(self):
# Testing using dtypes as record_defaults for required fields
@@ -373,13 +375,11 @@ class CsvDatasetOpTest(test.TestCase):
]]
file_path = self._setup_files(data)
- with ops.Graph().as_default() as g:
- ds = readers.make_csv_dataset(
- file_path, batch_size=1, shuffle=False, num_epochs=1)
- next_batch = ds.make_one_shot_iterator().get_next()
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ nxt = self._get_next(ds)
- with self.session(graph=g) as sess:
- result = list(sess.run(next_batch).values())
+ result = list(self.evaluate(nxt()).values())
self.assertEqual(result, sorted(result))
@@ -542,6 +542,29 @@ class CsvDatasetOpTest(test.TestCase):
compression_type='ZLIB',
record_defaults=record_defaults)
+ def testCsvDataset_withScalarDefaults(self):
+ record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_with2DDefaults(self):
+ record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ 'Each record default should be at '
+ 'most rank 1.')
+ else:
+ err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
class CsvDatasetBenchmark(test.Benchmark):
"""Benchmarks for the various ways of creating a dataset from CSV files.
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 61567bc8d7..25aea0393f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -115,7 +116,7 @@ class MapDefunTest(test.TestCase):
elems2 = array_ops.placeholder(dtypes.int32)
result = map_defun.map_defun(fn, [elems1, elems2],
[dtypes.int32, dtypes.int32], [(), ()])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
"All inputs must have the same dimension 0."):
@@ -207,6 +208,31 @@ class MapDefunTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(r, feed_dict={p: 0})
+ def _assert_op_cancelled(self, sess, map_defun_op):
+ with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
+ sess.run(map_defun_op)
+
+ def testMapDefunWithParentCancellation(self):
+ # Checks that a cancellation of the parent graph is threaded through to
+ # MapDefunOp correctly.
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ del x
+ queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
+ # Blocking
+ return queue.dequeue_many(5)
+
+ c = constant_op.constant([1, 2, 3, 4, 5])
+ map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
+
+ with self.cached_session() as sess:
+ thread = self.checkedThread(
+ self._assert_op_cancelled, args=(sess, map_defun_op))
+ thread.start()
+ time.sleep(0.1)
+ sess.close()
+ thread.join()
+
class MapDefunBenchmark(test.Benchmark):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 459bdf66f3..a2fc244ced 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -74,6 +74,58 @@ py_test(
)
py_test(
+ name = "map_parallelization_test",
+ size = "small",
+ srcs = ["map_parallelization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "model_dataset_op_test",
+ size = "medium",
+ srcs = ["model_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "noop_elimination_test",
+ size = "small",
+ srcs = ["noop_elimination_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:batching",
+ "//tensorflow/contrib/data/python/ops:interleave_ops",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "optimize_dataset_op_test",
size = "small",
srcs = ["optimize_dataset_op_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index bd7b50b902..d10da80442 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -31,7 +31,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(0, sess.run(get_next))
def testAssertNextInvalid(self):
@@ -40,7 +40,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Asserted Whoops transformation at offset 0 but encountered "
@@ -53,7 +53,7 @@ class AssertNextDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Asserted next 2 transformations but encountered only 1."):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index dde115925e..e75edf6086 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -200,7 +200,7 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
optimization.optimize(["filter_fusion"]))
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for x in range(5):
r = map_function(x)
filtered = False
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
new file mode 100644
index 0000000000..dd547db086
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -0,0 +1,84 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def assert_greater(x):
+ assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+ with ops.control_dependencies([assert_op]):
+ return x
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=0,
+ maxval=10,
+ dtype=dtypes.int64,
+ seed=42)
+
+ def assert_with_random(x):
+ x = assert_greater(x)
+ return random(x)
+
+ return (("Identity", identity, True), ("Increment", increment, True),
+ ("AssertGreater", assert_greater, True), ("Random", random, False),
+ ("AssertWithRandom", assert_with_random, False))
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapParallelization(self, function, should_optimize):
+ next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(next_nodes)).map(function).apply(
+ optimization.optimize(["map_parallelization"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ # No need to run the pipeline if it was not optimized. Also the results
+ # might be hard to check because of random.
+ if not should_optimize:
+ return
+ r = function(x)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index e2c9bc82df..5b493f44c9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -173,16 +173,6 @@ class MapVectorizationBenchmark(test.Benchmark):
self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
return median_time
- def benchmark_CheapFns(self):
-
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- for map_fn, str_id in self._get_known_cheap_fns():
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
num_elems = np.prod(input_size)
name_template = "{}__batch_size_{}_input_size_{}_{}"
@@ -205,14 +195,28 @@ class MapVectorizationBenchmark(test.Benchmark):
"Speedup: {}\n".format(batch_size, input_size, str_id,
(unoptimized_time / optimized_time)))
- def _get_known_cheap_fns(self):
- return [
- (lambda *args: [array_ops.identity(x) for x in args], "identity"),
- (lambda *args: [x + 1 for x in args], "add_const"),
- (lambda *args: args[0], "select"),
- (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
- "cast"),
- ]
+ # Known cheap functions
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
new file mode 100644
index 0000000000..3b62a7e468
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -0,0 +1,182 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ModelDatasetTest(test.TestCase):
+
+ def testModelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelMapAndBatch(self):
+ batch_size = 16
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(10):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelInterleave(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelNested(self):
+ k = 1024 * 1024
+ a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+ b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+ c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+ dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+ def f1(a, b, c):
+ x, y = a
+ return math_ops.matmul(x, y), b, c
+
+ def f2(a, b, c):
+ x, y = b
+ return a, math_ops.matmul(x, y), c
+
+ def f3(a, b, c):
+ x, y = c
+ return a, b, math_ops.matmul(x, y)
+
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
new file mode 100644
index 0000000000..507feda3ad
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test.TestCase):
+
+ def testNoopElimination(self):
+ a = constant_op.constant(1, dtype=dtypes.int64)
+ b = constant_op.constant(2, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+
+ dataset = dataset_ops.Dataset.range(5)
+ dataset = dataset.apply(
+ optimization.assert_next(
+ ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+ dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+ 0).repeat(1).prefetch(0)
+ dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ self.assertAllEqual(result, x)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index 909da5aee0..a3fb824ce9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -38,7 +38,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -51,7 +51,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -64,7 +64,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -76,7 +76,7 @@ class OptimizeDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(get_next)
def testOptimizationLargeInputFromTensor(self):
@@ -87,7 +87,7 @@ class OptimizeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
sess.run(get_next)
@@ -99,7 +99,7 @@ class OptimizeDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 0166ba0d44..33a64ea767 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -944,155 +943,5 @@ class CopyToDeviceTest(test.TestCase):
sess.run(elem_value_t)
-class MultiDeviceIteratorTest(test.TestCase):
-
- def testBasic(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testOneOnSameDevice(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:0", "/cpu:1"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testRepeatDevices(self):
- with ops.device("/cpu:0"):
- dataset = dataset_ops.Dataset.range(20)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
- elements = multi_device_iterator.get_next()
- elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 20, 4):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(i + 2, sess.run(elem_on_3))
- self.assertEqual(i + 3, sess.run(elem_on_4))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
- sess.run(elem_on_3)
- sess.run(elem_on_4)
-
- def testNotFullyDivisible(self):
- dataset = dataset_ops.Dataset.range(9)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 8, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- self.assertEqual(8, sess.run(elem_on_1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUneven(self):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testMultipleInitializations(self):
- with ops.device("/cpu:0"):
- epoch = array_ops.placeholder(dtypes.int64, shape=[])
- dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
- dataset2 = dataset_ops.Dataset.range(1000)
- dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
- init_op = multi_device_iterator.initializer
-
- config = config_pb2.ConfigProto(device_count={"CPU": 3})
- with self.test_session(config=config) as sess:
- for i in range(1000):
- sess.run(init_op, feed_dict={epoch: i})
- self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
-
- def testBasicGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"])
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
- def testUnevenGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- dataset = dataset_ops.Dataset.range(10)
- multi_device_iterator = prefetching_ops.MultiDeviceIterator(
- dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
- elem_on_1, elem_on_2 = multi_device_iterator.get_next()
-
- config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
- with self.test_session(config=config) as sess:
- sess.run(multi_device_iterator.initializer)
- for i in range(0, 10, 2):
- self.assertEqual(i, sess.run(elem_on_1))
- for i in range(0, 10, 2):
- self.assertEqual(i + 1, sess.run(elem_on_2))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index e25570c5ad..be8ae5e955 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -25,6 +25,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -40,7 +41,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
expected_sum = 0.0
for i in range(100):
@@ -65,7 +66,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -84,7 +85,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertAllEqual(
@@ -92,6 +93,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
@@ -100,6 +103,53 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
+ def testPrefetchBufferScalars(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testFilteredElementsStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(101).filter(
+ lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(34):
+ self.assertEqual(i * 3, sess.run(next_element))
+ if i is not 0:
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@@ -109,7 +159,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for j in range(5):
sess.run(iterator.initializer)
for i in range(100):
@@ -127,7 +177,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -144,7 +194,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -168,7 +218,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
@@ -188,7 +238,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 2f5a44408f..b1b4c23510 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -25,6 +25,14 @@ from tensorflow.python.platform import test
class StatsDatasetTestBase(test.TestCase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
+ def _assertSummaryContains(self, summary_str, tag):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
@@ -52,3 +60,12 @@ class StatsDatasetTestBase(test.TestCase):
self.assertEqual(expected_value, value.histo.sum)
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.simple_value)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 6eaa0b1959..8b7b3ac0f7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -89,13 +89,14 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
return dataset_ops.Dataset.zip(
tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
- dataset = self._structuredDataset(structure, shape, dtype).apply(
+ dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
+ for _ in range(5):
+ actual = sess.run(get_next)
+ self._assertEqual(expected, actual)
@parameterized.named_parameters(
("1", None, np.int32([]), dtypes.bool),
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 4b45cc7e36..a14781cd93 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,6 +80,7 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 367c159dc5..7a0f221284 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -345,12 +345,12 @@ def _padded_batch_sparse_window(dataset, padded_shape):
dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-class _UnbatchDataset(dataset_ops.Dataset):
+class _UnbatchDataset(dataset_ops.UnaryDataset):
"""A dataset that splits the elements of its input into multiple elements."""
def __init__(self, input_dataset):
"""See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__()
+ super(_UnbatchDataset, self).__init__(input_dataset)
flat_shapes = nest.flatten(input_dataset.output_shapes)
if any(s.ndims == 0 for s in flat_shapes):
raise ValueError("Cannot unbatch an input with scalar components.")
@@ -514,12 +514,12 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.Dataset):
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
def __init__(self, input_dataset, batch_size, row_shape):
"""See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__()
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
if not isinstance(input_dataset.output_types, dtypes.DType):
raise TypeError("DenseToSparseDataset requires an input whose elements "
"have a single component, whereas the input has %r." %
@@ -548,7 +548,7 @@ class _DenseToSparseBatchDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _RestructuredDataset(dataset_ops.Dataset):
+class _RestructuredDataset(dataset_ops.UnaryDataset):
"""An internal helper for changing the structure and shape of a dataset."""
def __init__(self,
@@ -583,7 +583,7 @@ class _RestructuredDataset(dataset_ops.Dataset):
ValueError: If either `output_types` or `output_shapes` is not compatible
with the structure of `dataset`.
"""
- super(_RestructuredDataset, self).__init__()
+ super(_RestructuredDataset, self).__init__(dataset)
self._input_dataset = dataset
if not allow_unsafe_cast:
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index b4a7521e08..615dbcabd4 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -51,12 +51,12 @@ def ignore_errors():
return _apply_fn
-class _IgnoreErrorsDataset(dataset_ops.Dataset):
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that silently ignores errors when computing its input."""
def __init__(self, input_dataset):
"""See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__()
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
def _as_variant_tensor(self):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6edc1d7990..7cae33beb3 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -124,7 +124,8 @@ def bucket_by_sequence_length(element_length_func,
bucket_batch_sizes,
padded_shapes=None,
padding_values=None,
- pad_to_bucket_boundary=False):
+ pad_to_bucket_boundary=False,
+ no_padding=False):
"""A transformation that buckets elements in a `Dataset` by length.
Elements of the `Dataset` are grouped together by length and then are padded
@@ -152,6 +153,8 @@ def bucket_by_sequence_length(element_length_func,
unknown size to bucket boundary minus 1 (i.e., the maximum length in each
bucket), and caller must ensure that the source `Dataset` does not contain
any elements with length longer than `max(bucket_boundaries)`.
+ no_padding: `bool`, indicates whether to pad the batch features (features
+ need to be either of type `tf.SparseTensor` or of same shape).
Returns:
A `Dataset` transformation function, which can be passed to
@@ -199,7 +202,9 @@ def bucket_by_sequence_length(element_length_func,
def batching_fn(bucket_id, grouped_dataset):
"""Batch elements in dataset."""
- batch_size = batch_sizes[bucket_id]
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
none_filler = None
if pad_to_bucket_boundary:
err_msg = ("When pad_to_bucket_boundary=True, elements must have "
@@ -250,6 +255,7 @@ def _map_x_dataset(map_func):
return _apply_fn
+# TODO(b/115382007) Remove this once canned reducers move to core.
def window_dataset(window_size):
"""A transformation that creates window datasets from the input dataset.
@@ -266,17 +272,22 @@ def window_dataset(window_size):
"""
def _apply_fn(dataset):
- return _WindowDataset(dataset, window_size)
+ return dataset_ops.WindowDataset(
+ dataset,
+ size=window_size,
+ shift=window_size,
+ stride=1,
+ drop_remainder=False)
return _apply_fn
-class _GroupByReducerDataset(dataset_ops.Dataset):
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a reduction."""
def __init__(self, input_dataset, key_func, reducer):
"""See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__()
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -405,12 +416,12 @@ class _GroupByReducerDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
-class _GroupByWindowDataset(dataset_ops.Dataset):
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
"""See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__()
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
@@ -514,12 +525,12 @@ class Reducer(object):
return self._finalize_func
-class _MapXDataset(dataset_ops.Dataset):
+class _MapXDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func):
"""See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__()
+ super(_MapXDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = dataset_ops.StructuredFunctionWrapper(
@@ -551,46 +562,3 @@ class _MapXDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
-
-
-class _WindowDataset(dataset_ops.Dataset):
- """A dataset that creates window datasets from the input elements."""
-
- def __init__(self, input_dataset, window_size):
- """See `window_dataset()` for more details."""
- super(_WindowDataset, self).__init__()
- self._input_dataset = input_dataset
- self._window_size = ops.convert_to_tensor(
- window_size, dtype=dtypes.int64, name="window_size")
- self._output_classes = nest.pack_sequence_as(
- input_dataset.output_classes,
- [
- dataset_ops._NestedDatasetComponent( # pylint: disable=protected-access
- output_classes=output_class,
- output_shapes=output_shape,
- output_types=output_type)
- for output_class, output_shape, output_type in zip(
- nest.flatten(input_dataset.output_classes),
- nest.flatten(input_dataset.output_shapes),
- nest.flatten(input_dataset.output_types))
- ])
- self._output_shapes = self._output_classes
- self._output_types = self._output_classes
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._window_size,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index a0932b4081..cc76ab0850 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -171,3 +171,6 @@ class IdentityIndexedDataset(IndexedDataset):
def _as_variant_tensor(self):
return gen_dataset_ops.identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 92d4251a86..bfa3fdf543 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -173,6 +173,9 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
**dataset_ops.flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
@property
def output_classes(self):
return self._data_inputs[0].output_classes
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index fa1b851ad7..3eb172acd5 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -24,6 +24,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
# account for indexing) and transformation sequence.
@@ -46,6 +49,21 @@ def assert_next(transformations):
return _apply_fn
+def model():
+ """A transformation that models performance.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _ModelDataset(dataset)
+
+ return _apply_fn
+
+
def optimize(optimizations=None):
"""A transformation that applies optimizations.
@@ -66,12 +84,12 @@ def optimize(optimizations=None):
return _apply_fn
-class _AssertNextDataset(dataset_ops.Dataset):
+class _AssertNextDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that asserts which transformations happen next."""
def __init__(self, input_dataset, transformations):
"""See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__()
+ super(_AssertNextDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if transformations is None:
raise ValueError("At least one transformation should be specified")
@@ -97,12 +115,38 @@ class _AssertNextDataset(dataset_ops.Dataset):
return self._input_dataset.output_types
-class _OptimizeDataset(dataset_ops.Dataset):
+class _ModelDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and models performance."""
+
+ def __init__(self, input_dataset):
+ """See `optimize()` for details."""
+ super(_ModelDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.model_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _OptimizeDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and applies optimizations."""
def __init__(self, input_dataset, optimizations):
"""See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__()
+ super(_OptimizeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if optimizations is None:
optimizations = []
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index 2701605e64..cfbba701b0 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -26,11 +26,11 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
-class _ParseExampleDataset(dataset_ops.Dataset):
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that parses `example` dataset into a `dict` dataset."""
def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__()
+ super(_ParseExampleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if not all(types == dtypes.string
for types in nest.flatten(input_dataset.output_types)):
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 5222011d04..f994425304 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -262,10 +262,11 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to another device."""
def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
@@ -374,7 +375,7 @@ def copy_to_device(target_device, source_device="/cpu:0"):
# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
# all inputs to the Op are in host memory, thereby avoiding some unnecessary
# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.Dataset):
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that copies elements to another device."""
def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
@@ -385,6 +386,7 @@ class _CopyToDeviceDataset(dataset_ops.Dataset):
target_device: The name of the device to which elements would be copied.
source_device: Device where input_dataset would be placed.
"""
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._target_device = target_device
spec = framework_device.DeviceSpec().from_string(self._target_device)
@@ -612,6 +614,10 @@ class _PerDeviceGenerator(dataset_ops.Dataset):
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
@property
def output_types(self):
return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index e670c4c835..344a0763c8 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
-class RandomDataset(dataset_ops.Dataset):
+class RandomDataset(dataset_ops.DatasetSource):
"""A `Dataset` of pseudorandom values."""
def __init__(self, seed=None):
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 4c466781f7..d9d06e2703 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
@@ -214,18 +215,17 @@ def _maybe_shuffle_and_repeat(
return dataset
-def make_tf_record_dataset(
- file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=None,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
"""Reads and optionally parses TFRecord files into a dataset.
Provides common functionality such as batching, optional parsing, shuffling,
@@ -300,8 +300,6 @@ def make_tf_record_dataset(
parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
drop_remainder=drop_final_batch))
- if prefetch_buffer_size is None:
- prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE
if prefetch_buffer_size == 0:
return dataset
else:
@@ -323,7 +321,7 @@ def make_csv_dataset(
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
num_parallel_reads=1,
sloppy=False,
num_rows_for_inference=100,
@@ -386,9 +384,10 @@ def make_csv_dataset(
shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
ensures better shuffling, but increases memory usage and startup time.
shuffle_seed: Randomization seed to use for shuffling.
- prefetch_buffer_size: An int specifying the number of feature batches to
- prefetch for performance improvement. Recommended value is the number of
- batches consumed per training step.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -509,7 +508,7 @@ def make_csv_dataset(
_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-class CsvDataset(dataset_ops.Dataset):
+class CsvDataset(dataset_ops.DatasetSource):
"""A Dataset comprising lines from one or more CSV files."""
def __init__(self,
@@ -666,7 +665,7 @@ def make_batched_features_dataset(file_pattern,
shuffle=True,
shuffle_buffer_size=10000,
shuffle_seed=None,
- prefetch_buffer_size=1,
+ prefetch_buffer_size=optimization.AUTOTUNE,
reader_num_threads=1,
parser_num_threads=2,
sloppy_ordering=False,
@@ -739,7 +738,7 @@ def make_batched_features_dataset(file_pattern,
shuffle_seed: Randomization seed to use for shuffling.
prefetch_buffer_size: Number of feature batches to prefetch in order to
improve performance. Recommended value is the number of batches consumed
- per training step (default is 1).
+ per training step. Defaults to auto-tune.
reader_num_threads: Number of threads used to read `Example` records. If >1,
the results will be interleaved.
parser_num_threads: Number of threads to use for parsing `Example` tensors
@@ -925,7 +924,7 @@ def _get_file_names(file_pattern, shuffle):
return file_names
-class SqlDataset(dataset_ops.Dataset):
+class SqlDataset(dataset_ops.DatasetSource):
"""A `Dataset` consisting of the results from a SQL query."""
def __init__(self, driver_name, data_source_name, query, output_types):
@@ -986,7 +985,7 @@ class SqlDataset(dataset_ops.Dataset):
return self._output_types
-class LMDBDataset(dataset_ops.Dataset):
+class LMDBDataset(dataset_ops.DatasetSource):
"""A LMDB Dataset that reads the lmdb file."""
def __init__(self, filenames):
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 6b002b4a53..c52582cd35 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -27,12 +27,12 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_dataset_ops
-class _ScanDataset(dataset_ops.Dataset):
+class _ScanDataset(dataset_ops.UnaryDataset):
"""A dataset that scans a function across its input."""
def __init__(self, input_dataset, initial_state, scan_func):
"""See `scan()` for details."""
- super(_ScanDataset, self).__init__()
+ super(_ScanDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 4356721704..985d1d87d0 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -25,16 +25,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-class _ShuffleAndRepeatDataset(dataset_ops.Dataset):
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that fuses `shuffle` and `repeat`."""
- def __init__(self,
- input_dataset,
- buffer_size,
- count=None,
- seed=None):
- """See `Dataset.map()` for details."""
- super(_ShuffleAndRepeatDataset, self).__init__()
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py
index 8025dcdd16..bcc383587c 100644
--- a/tensorflow/contrib/data/python/ops/sliding.py
+++ b/tensorflow/contrib/data/python/ops/sliding.py
@@ -26,12 +26,12 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
-class _SlideDataset(dataset_ops.Dataset):
+class _SlideDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that passes a sliding window over its input."""
def __init__(self, input_dataset, window_size, window_shift, window_stride):
"""See `sliding_window_batch` for details."""
- super(_SlideDataset, self).__init__()
+ super(_SlideDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._window_size = ops.convert_to_tensor(
window_size, dtype=dtypes.int64, name="window_stride")
@@ -67,6 +67,10 @@ class _SlideDataset(dataset_ops.Dataset):
@deprecation.deprecated_args(
None, "stride is deprecated, use window_shift instead", "stride")
+@deprecation.deprecated(
+ None, "Use `tf.data.Dataset.window(size=window_size, shift=window_shift, "
+ "stride=window_stride).flat_map(lambda x: x.batch(window.size))` "
+ "instead.")
def sliding_window_batch(window_size,
stride=None,
window_shift=None,
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 8426228992..bc47c5989d 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,34 +23,31 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
- iterator (see below). For example, to record the total number of bytes
- produced by iterating over a dataset:
+ iterator (see below). For example, to record the latency of producing each
+ element by iterating over a dataset:
```python
dataset = ...
- dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
+ dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
```
- To associate a `StatsAggregator` with a `tf.data.Iterator` object, use
+ To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
- dataset = ...
- iterator = dataset.make_one_shot_iterator()
stats_aggregator = stats_ops.StatsAggregator()
- set_op = stats_aggregator.subscribe(iterator)
+ dataset = ...
- with tf.Session() as sess:
- # Running `set_op` will associate `iterator` with `stats_aggregator`.
- sess.run(set_op)
+ # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+ dataset = dataset.apply(
+ tf.contrib.data.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_one_shot_iterator()
```
To get a protocol buffer summary of the currently aggregated statistics,
@@ -60,6 +57,7 @@ class StatsAggregator(object):
```python
stats_aggregator = stats_ops.StatsAggregator()
+ # ...
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```
@@ -73,6 +71,7 @@ class StatsAggregator(object):
"""Creates a `StatsAggregator`."""
self._resource = gen_dataset_ops.stats_aggregator_handle()
+ # TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -85,11 +84,11 @@ class StatsAggregator(object):
return gen_dataset_ops.stats_aggregator_summary(self._resource)
-class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__()
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
@@ -112,13 +111,11 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def set_stats_aggregator(stats_aggregator):
- """Set the given stats_aggregator for aggregating the input dataset stats.
+ """Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `StatsAggregator` object.
+ stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -155,8 +152,6 @@ def bytes_produced_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
@@ -178,11 +173,11 @@ def latency_stats(tag):
return _apply_fn
-class _StatsDataset(dataset_ops.Dataset):
+class _StatsDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__()
+ super(_StatsDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._op_function = op_function
self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index dc67accdcf..9d165ad52a 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -61,11 +61,11 @@ class PrivateThreadPool(object):
display_name=display_name)
-class _ThreadPoolDataset(dataset_ops.Dataset):
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets a custom threadpool."""
def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__()
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._thread_pool = thread_pool
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index e0d606311c..bad67a580d 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -47,12 +47,12 @@ def unique():
return _apply_fn
-class _UniqueDataset(dataset_ops.Dataset):
+class _UniqueDataset(dataset_ops.UnaryDataset):
"""A `Dataset` contains the unique elements from its input."""
def __init__(self, input_dataset):
"""See `unique()` for details."""
- super(_UniqueDataset, self).__init__()
+ super(_UniqueDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
dtypes.string):
diff --git a/tensorflow/contrib/deprecated/summaries_test.py b/tensorflow/contrib/deprecated/summaries_test.py
index 6acf2a6469..4038224a1c 100644
--- a/tensorflow/contrib/deprecated/summaries_test.py
+++ b/tensorflow/contrib/deprecated/summaries_test.py
@@ -27,31 +27,31 @@ from tensorflow.python.platform import test
class DeprecatedSummariesTest(test.TestCase):
def testScalarSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
s = logging_ops.scalar_summary('tag', c)
self.assertEqual(s.op.type, u'ScalarSummary')
def testHistogramSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
s = logging_ops.histogram_summary('tag', c)
self.assertEqual(s.op.type, u'HistogramSummary')
def testImageSummary(self):
- with self.test_session():
+ with self.cached_session():
i = array_ops.ones((5, 4, 4, 3))
s = logging_ops.image_summary('tag', i)
self.assertEqual(s.op.type, u'ImageSummary')
def testAudioSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3.0)
s = logging_ops.audio_summary('tag', c, sample_rate=8000)
self.assertEqual(s.op.type, u'AudioSummaryV2')
def testMergeSummary(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(3)
a = logging_ops.scalar_summary('a', c)
b = logging_ops.scalar_summary('b', c)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 30e1992c01..91a27f97b7 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -76,7 +76,7 @@ We then compile the Keras model and pass the `MirroredStrategy` object in the
```python
model.compile(loss='mean_squared_error',
optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
- distribute=strategy)
+ distribute=distribution)
```
To train the model we call Keras `fit` API using the input dataset that we
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 350f81f60f..823fe6a917 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Prototype of a distributed computation library for TF."""
+"""A distributed computation library for TF.
+
+See [tensorflow/contrib/distribute/README.md](
+https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+for overview and examples.
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 87f76eaa94..7eead6e472 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,7 +22,6 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
- ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -30,6 +29,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -472,11 +472,8 @@ cuda_py_test(
"//tensorflow/python:summary",
],
tags = [
- "manual",
"multi_and_single_gpu",
"no_pip",
- "nogpu",
- "notap",
],
)
@@ -485,7 +482,6 @@ py_library(
srcs = ["single_loss_example.py"],
deps = [
":step_fn",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
@@ -652,32 +648,6 @@ cuda_py_test(
)
py_library(
- name = "prefetching_ops_v2",
- srcs = ["prefetching_ops_v2.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_v2_test",
- srcs = ["prefetching_ops_v2_test.py"],
- additional_deps = [
- ":prefetching_ops_v2",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 77079d0df9..c900b41e14 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -143,8 +143,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
def _real_mirrored_creator(devices, *args, **kwargs):
"""Creates one MirroredVariable on the current worker."""
index = {}
+ unique_var_name = ops.get_default_graph().unique_name(
+ kwargs["name"], mark_as_used=False).rstrip("/")
collective_instance_key = self._collective_keys.get_instance_key(
- key_id=kwargs["name"])
+ key_id=unique_var_name)
if "initial_value" not in kwargs:
raise ValueError("Initial value must be specified.")
initial_value = kwargs["initial_value"]
@@ -188,6 +190,10 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
+ if i == 0:
+ actual_var_name = v.name.split(":")[0]
+ assert unique_var_name == actual_var_name, "%r vs %r" % (
+ unique_var_name, actual_var_name)
assert not isinstance(v, values.DistributedVariable)
index[d] = v
return index
@@ -229,8 +235,6 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
if not session_config or not self._cluster_spec:
return
- session_config.isolate_session_state = True
-
assert self._task_type
assert self._task_id is not None
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 36e9761073..33ffbf6abe 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -34,9 +35,14 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import training_util
class CollectiveAllReduceStrategyTestBase(
@@ -146,6 +152,56 @@ class CollectiveAllReduceStrategyTestBase(
self.assertLess(error_after, error_before)
return error_after < error_before
+ def _test_complex_model(self, task_type, task_id, num_gpus):
+ d, master_target = self._get_test_object(task_type, task_id, num_gpus)
+
+ def model_fn():
+ """Mnist model with synthetic input."""
+ data_format = 'channels_last'
+ input_shape = [28, 28, 1]
+ l = keras.layers
+ max_pool = l.MaxPooling2D((2, 2), (2, 2),
+ padding='same',
+ data_format=data_format)
+ model = keras.Sequential([
+ l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
+ l.Conv2D(
+ 32,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Conv2D(
+ 64,
+ 5,
+ padding='same',
+ data_format=data_format,
+ activation=nn.relu), max_pool,
+ l.Flatten(),
+ l.Dense(1024, activation=nn.relu),
+ l.Dropout(0.4),
+ l.Dense(10)
+ ])
+ image = random_ops.random_uniform([2, 28, 28])
+ label = random_ops.random_uniform([2, 1], maxval=10, dtype=dtypes.int32)
+ logits = model(image, training=True)
+ loss = losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
+ optimizer = adam.AdamOptimizer(learning_rate=1e-4)
+ train_op = optimizer.minimize(loss,
+ training_util.get_or_create_global_step())
+ return train_op
+
+ with ops.Graph().as_default(), \
+ self.test_session(config=self._sess_config,
+ target=master_target) as sess:
+ with d.scope():
+ train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(d.unwrap(train_op))
+
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ return True
+
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -206,6 +262,14 @@ class DistributedCollectiveAllReduceStrategyTest(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -236,6 +300,14 @@ class DistributedCollectiveAllReduceStrategyTestWithChief(
self._cluster_spec,
num_gpus=num_gpus)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testComplexModel(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+
class LocalCollectiveAllReduceStrategy(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -246,6 +318,12 @@ class LocalCollectiveAllReduceStrategy(
return
self._test_minimize_loss_graph(None, None, num_gpus)
+ def testComplexModel(self, num_gpus=2):
+ # Collective ops doesn't support strategy with one device.
+ if context.num_gpus() < num_gpus:
+ return
+ self._test_complex_model(None, None, num_gpus)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 1133be6d0b..244d1fcec8 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -50,10 +50,12 @@ from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib
from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib
from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib
+from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
@@ -347,17 +349,23 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
+ "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn]
+adagrad_optimizer_v1_fn = NamedObject(
+ "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
+ adagrad_optimizer_v1_fn]
adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.2, epsilon=1))
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn]
+adagrad_optimizer_v2_fn = NamedObject(
+ "AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
+optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
+ adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py
index 24cb08fb48..9fc1b88955 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_utils.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py
@@ -221,9 +221,12 @@ def split_grads_by_size(threshold_size, device_grads):
return small_grads, large_grads
-# threading.Lock() cannot be pickled and therefore cannot be a field of
-# CollectiveKeys.
+# threading.Lock() and threading.local() cannot be pickled and therefore cannot
+# be a field of CollectiveKeys. Right now _thread_local is not necessary to be
+# an instance member of CollectiveKeys since we always create a new thread for
+# each tower.
_lock = threading.Lock()
+_thread_local = threading.local()
# TODO(yuefengz): use random key starts to avoid reusing keys?
@@ -266,14 +269,12 @@ class CollectiveKeys(object):
# For instance keys without ids
self._instance_key_start = instance_key_start
- self._thread_local = threading.local()
-
def _get_thread_local_object(self):
# We make instance key without key ids thread local so that it will work
# with MirroredStrategy and distribute coordinator.
- if not hasattr(self._thread_local, 'instance_key'):
- self._thread_local.instance_key = self._instance_key_start
- return self._thread_local
+ if not hasattr(_thread_local, 'instance_key'):
+ _thread_local.instance_key = self._instance_key_start
+ return _thread_local
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 5348512016..157618f72f 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -26,21 +26,12 @@ import tempfile
import threading
from absl.testing import parameterized
import numpy as np
-import six
-_portpicker_import_error = None
-try:
- import portpicker # pylint: disable=g-import-not-at-top
-except ImportError as _error: # pylint: disable=invalid-name
- _portpicker_import_error = _error
- portpicker = None
-
-# pylint: disable=g-import-not-at-top
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.contrib.optimizer_v2 import adagrad
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import estimator_training as dc_training
@@ -57,7 +48,6 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import server_lib
BATCH_SIZE = 10
LABEL_DIMENSION = 2
@@ -73,130 +63,38 @@ EVALUATOR = dc._TaskType.EVALUATOR
WORKER = dc._TaskType.WORKER
PS = dc._TaskType.PS
-original_run_distribute_coordinator = dc.run_distribute_coordinator
-
-
-# TODO(yuefengz): merge this method back to test_util.
-def _create_local_cluster(num_workers,
- num_ps,
- has_eval=False,
- protocol="grpc",
- worker_config=None,
- ps_config=None):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
-
- cluster_dict = {
- "worker": ["localhost:%s" % port for port in worker_ports],
- "ps": ["localhost:%s" % port for port in ps_ports]
- }
- if has_eval:
- cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
-
- cs = server_lib.ClusterSpec(cluster_dict)
-
- workers = [
- server_lib.Server(
- cs,
- job_name="worker",
- protocol=protocol,
- task_index=ix,
- config=worker_config,
- start=True) for ix in range(num_workers)
- ]
- ps_servers = [
- server_lib.Server(
- cs,
- job_name="ps",
- protocol=protocol,
- task_index=ix,
- config=ps_config,
- start=True) for ix in range(num_ps)
- ]
- if has_eval:
- evals = [
- server_lib.Server(
- cs,
- job_name="evaluator",
- protocol=protocol,
- task_index=0,
- config=worker_config,
- start=True)
- ]
- else:
- evals = []
-
- return workers, ps_servers, evals
-
-
-def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
- """Create an in-process cluster that consists of only standard server."""
- # Leave some memory for cuda runtime.
- if has_eval:
- gpu_mem_frac = 0.7 / (num_workers + 1)
- else:
- gpu_mem_frac = 0.7 / num_workers
-
- worker_config = config_pb2.ConfigProto()
- worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
-
- # Enable collective ops which has no impact on non-collective ops.
- # TODO(yuefengz, tucker): removing this after we move the initialization of
- # collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- "/job:worker/replica:0/task:0")
-
- ps_config = config_pb2.ConfigProto()
- ps_config.device_count["GPU"] = 0
-
- return _create_local_cluster(
- num_workers,
- num_ps=num_ps,
- has_eval=has_eval,
- worker_config=worker_config,
- ps_config=ps_config,
- protocol="grpc")
-
-
-def _create_cluster_spec(has_chief=False,
- num_workers=1,
- num_ps=0,
- has_eval=False):
- if _portpicker_import_error:
- raise _portpicker_import_error # pylint: disable=raising-bad-type
-
- cluster_spec = {}
- if has_chief:
- cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
- if num_workers:
- cluster_spec[WORKER] = [
- "localhost:%s" % portpicker.pick_unused_port()
- for _ in range(num_workers)
- ]
- if num_ps:
- cluster_spec[PS] = [
- "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
- ]
- if has_eval:
- cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
- return cluster_spec
+original_run_std_server = dc._run_std_server
-def _bytes_to_str(maybe_bytes):
- if isinstance(maybe_bytes, six.string_types):
- return maybe_bytes
- else:
- return str(maybe_bytes, "utf-8")
+class MockOsEnv(dict):
+
+ def __init__(self, *args):
+ self._thread_local = threading.local()
+ super(MockOsEnv, self).__init__(*args)
+
+ def get(self, key, default):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.get(self._thread_local.dict, key, default)
+ else:
+ return dict.get(self, key, default)
+ def __getitem__(self, key):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__getitem__(self._thread_local.dict, key)
+ else:
+ return dict.__getitem__(self, key)
-def _strip_protocol(target):
- # cluster_spec expects "host:port" strings.
- if "//" in target:
- return target.split("//")[1]
- else:
- return target
+ def __setitem__(self, key, val):
+ if not hasattr(self._thread_local, "dict"):
+ self._thread_local.dict = dict()
+ if key == "TF_CONFIG":
+ return dict.__setitem__(self._thread_local.dict, key, val)
+ else:
+ return dict.__setitem__(self, key, val)
class DistributeCoordinatorIntegrationTest(test.TestCase,
@@ -205,22 +103,20 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=2, has_eval=True)
- cls._cluster_spec = {
- "worker": [
- _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
- ],
- "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
- "evaluator": [
- _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
- ]
- }
def setUp(self):
self._model_dir = tempfile.mkdtemp()
- self._event = threading.Event()
+ self._mock_os_env = MockOsEnv()
+ self._mock_context = test.mock.patch.object(os, "environ",
+ self._mock_os_env)
super(DistributeCoordinatorIntegrationTest, self).setUp()
+ self._mock_context.__enter__()
+
+ def tearDown(self):
+ self._mock_context.__exit__(None, None, None)
+ super(DistributeCoordinatorIntegrationTest, self).tearDown()
def dataset_input_fn(self, x, y, batch_size, shuffle):
@@ -391,43 +287,17 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
self._inspect_train_and_eval_events(estimator)
- def _mock_run_distribute_coordinator(
- self,
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=dc.CoordinatorMode.STANDALONE_CLIENT,
- cluster_spec=None,
- session_config=None):
- # Calls the origial `run_distribute_coordinator` method but gets task config
- # from environment variables and then signals the caller.
- task_type = None
- task_id = None
- if not cluster_spec:
- cluster_spec = None
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- if not cluster_spec:
- cluster_spec = tf_config.get("cluster", {})
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", task_type)
- task_id = int(task_env.get("index", task_id))
- self._event.set()
- original_run_distribute_coordinator(
- worker_fn,
- strategy,
- eval_fn,
- eval_strategy,
- mode=mode,
- cluster_spec=cluster_spec,
- task_type=task_type,
- task_id=task_id,
- session_config=session_config)
-
- def _task_thread(self, train_distribute, eval_distribute):
- with test.mock.patch.object(dc, "run_distribute_coordinator",
- self._mock_run_distribute_coordinator):
+ def _mock_run_std_server(self, *args, **kwargs):
+ ret = original_run_std_server(*args, **kwargs)
+ # Wait for all std servers to be brought up in order to reduce the chance of
+ # remote sessions taking local ports that have been assigned to std servers.
+ self._barrier.wait()
+ return ret
+
+ def _task_thread(self, train_distribute, eval_distribute, tf_config):
+ os.environ["TF_CONFIG"] = json.dumps(tf_config)
+ with test.mock.patch.object(dc, "_run_std_server",
+ self._mock_run_std_server):
self._complete_flow(train_distribute, eval_distribute)
def _run_task_in_thread(self, cluster_spec, task_type, task_id,
@@ -448,13 +318,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
"index": task_id
}
}
- self._event.clear()
t = threading.Thread(
- target=self._task_thread, args=(train_distribute, eval_distribute))
- with test.mock.patch.dict("os.environ",
- {"TF_CONFIG": json.dumps(tf_config)}):
- t.start()
- self._event.wait()
+ target=self._task_thread,
+ args=(train_distribute, eval_distribute, tf_config))
+ t.start()
return t
def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
@@ -489,7 +356,11 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=2, has_eval=True)
+ # 3 workers, 2 ps and 1 evaluator.
+ self._barrier = dc._Barrier(6)
+
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
for task_type, ts in threads.items():
@@ -516,7 +387,10 @@ class DistributeCoordinatorIntegrationTest(test.TestCase,
else:
eval_distribute = None
- cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ cluster_spec = multi_worker_test_base.create_cluster_spec(
+ num_workers=3, num_ps=0, has_eval=True)
+ # 3 workers and 1 evaluator.
+ self._barrier = dc._Barrier(4)
threads = self._run_multiple_tasks_in_threads(
cluster_spec, train_distribute, eval_distribute)
threads[WORKER][0].join()
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index c5acb7ced4..559de97bb1 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -20,8 +20,6 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.distribute.python import input_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -126,20 +124,6 @@ class AutoShardDatasetTest(test.TestCase):
# contain records in order of files.
self._verifySimpleShardingOutput(dataset, self._record)
- def testParallelInterleave(self):
- dataset = dataset_ops.Dataset.from_tensor_slices(
- self._createTFRecordFiles())
- dataset = dataset.apply(interleave_ops.parallel_interleave(
- readers.TFRecordDataset,
- cycle_length=4,
- block_length=self._num_records))
- dataset = input_ops.auto_shard_dataset(
- dataset, self._num_shards, self._shard_index)
-
- # Since block_length == num records in each file, the output will still
- # contain records in order of files.
- self._verifySimpleShardingOutput(dataset, self._record)
-
def testListfiles(self):
filenames = self._createTFRecordFiles()
file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt"
@@ -171,8 +155,8 @@ class AutoShardDatasetTest(test.TestCase):
dataset = dataset.prefetch(buffer_size=batch_size)
dataset = dataset.shuffle(2 * self._num_files * self._num_records)
dataset = dataset.repeat(num_epochs)
- dataset = dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size=batch_size))
+ dataset = dataset.map(lambda x: x)
+ dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=None)
# Auto shard.
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 9e1762d92c..2e6cd43fd4 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -66,6 +67,32 @@ def simple_functional_model():
return model
+def multi_inputs_multi_outputs_model():
+ input_a = keras.layers.Input(shape=(16,), name='input_a')
+ input_b = keras.layers.Input(shape=(16,), name='input_b')
+ input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+ dense = keras.layers.Dense(8, name='dense_1')
+
+ interm_a = dense(input_a)
+ # Read m
+ interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+ interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+ interm_b = dense(input_b)
+ merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+ output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+ output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+ model = keras.models.Model(
+ inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer=gradient_descent.GradientDescentOptimizer(0.001),
+ metrics={
+ 'dense_2': 'categorical_accuracy',
+ 'dense_3': 'categorical_accuracy'
+ })
+ return model
+
+
def get_ds_train_input_fn():
np.random.seed(_RANDOM_SEED)
(x_train, y_train), _ = testing_utils.get_test_data(
@@ -94,6 +121,49 @@ def get_ds_test_input_fn():
return dataset
+def get_multi_inputs_multi_outputs_data():
+ (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=3,
+ random_seed=_RANDOM_SEED)
+ (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+ (m_train, _), (m_test, _) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(8,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+
+ c_train = keras.utils.to_categorical(c_train)
+ c_test = keras.utils.to_categorical(c_test)
+ d_train = keras.utils.to_categorical(d_train)
+ d_test = keras.utils.to_categorical(d_test)
+
+ train_data = {
+ 'input_a': a_train,
+ 'input_b': b_train,
+ 'input_m': m_train,
+ 'output_c': c_train,
+ 'output_d': d_train
+ }
+ test_data = {
+ 'input_a': a_test,
+ 'input_b': b_test,
+ 'input_m': m_test,
+ 'output_c': c_test,
+ 'output_d': d_test
+ }
+
+ return (train_data, test_data)
+
+
def batch_wrapper(dataset, batch_size, distribution):
# TPUs currently require fully defined input shapes, drop_remainder ensures
# the input will have fully defined shapes.
@@ -121,6 +191,8 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.MakeDirs(self._base_dir)
self._config = run_config_lib.RunConfig(
tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+ self._dist = mirrored_strategy.MirroredStrategy(
+ devices=['/device:GPU:0', '/device:GPU:1'])
def tearDown(self):
writer_cache.FileWriterCache.clear()
@@ -174,6 +246,53 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+ def train_input_fn():
+ input_dict = {
+ 'input_a': train_data['input_a'],
+ 'input_b': train_data['input_b'],
+ 'input_m': train_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': train_data['output_c'],
+ 'dense_3': train_data['output_d']
+ }
+ return dataset_ops.Dataset.from_tensor_slices((input_dict,
+ output_dict)).batch(16)
+
+ def eval_input_fn():
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': test_data['output_c'],
+ 'dense_3': test_data['output_d']
+ }
+ return dataset_ops.Dataset.from_tensor_slices((input_dict,
+ output_dict)).batch(16)
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn)
+
+ def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn,
+ eval_input_fn):
+ config = run_config_lib.RunConfig(
+ tf_random_seed=_RANDOM_SEED,
+ model_dir=self._base_dir,
+ train_distribute=self._dist)
+ with self.cached_session():
+ model = multi_inputs_multi_outputs_model()
+ est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)
+ baseline_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
+ eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+
def test_keras_optimizer_with_distribution_strategy(self):
dist = mirrored_strategy.MirroredStrategy(
devices=['/device:GPU:0', '/device:GPU:1'])
@@ -516,6 +635,29 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+ @combinations.generate(combinations.combine(
+ distribution=[combinations.tpu_strategy_one_step],
+ mode=['graph']))
+ def test_dataset_input_shape_fully_defined(self, distribution):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ # Input shapes are not fully known. Batch dimension is unknown as we are
+ # not using the drop_remainder argument.
+ dataset = dataset.repeat(100).batch(10)
+
+ with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
def test_learning_phase_value(self):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
@@ -613,14 +755,22 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
+
+ # Train and predict datasets are created with the same input numpy arrays.
x_train = np.random.rand(num_samples, 1)
y_train = 3 * x_train
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
+ # The model is built once and the initial weights are saved.
+ # This is used to initialize the model for both the distribution and
+ # non-distribution run.
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ initial_weights = model.get_weights()
+
def fit_and_predict(with_distribution=None):
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.set_weights(initial_weights)
model.compile(
loss=keras.losses.mean_squared_error,
optimizer=gradient_descent.GradientDescentOptimizer(0.5),
@@ -632,12 +782,14 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
- # Running only 100 steps instead of the full dataset to keep test
- # duration small.
- model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+ # We have initialized the model to the same weight for the distribution
+ # and non-distribution run. If you want to initialize the model to
+ # random weights for each run, you need to run the model through the
+ # entire dataset at least once to ensure that the weights converge to
+ # the same value.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
weights = model.get_weights()
-
x_predict = [[1.], [2.], [3.], [4.]]
predict_batch_size = 4
if with_distribution:
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..f7773aff4f 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,10 +86,11 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ dataset_fn).make_initializable_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
+ self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index bdac4fb58c..d082d5c419 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
+ def _get_iterator(self, ds):
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate(iterator.initializer)
+ return iterator
+
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -183,6 +188,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
"dense/kernel", "dense/bias", "beta1_power", "beta2_power",
"dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
"dense/bias/Adam_1"
+ ],
+ "Adagrad": [
+ "dense/kernel/Adagrad", "dense/kernel",
+ "dense/bias/Adagrad", "dense/bias"
]
}
variables = variables_map[optimizer_fn().get_name()]
@@ -240,8 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -334,8 +342,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
return distribution.run_steps_on_dataset(
@@ -428,8 +435,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c6805d682..945f450387 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -480,8 +480,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ self._call_dataset_fn(dataset_fn),
+ self._devices,
+ self._prefetch_on_device,
+ source_device=device_util.resolve("/device:CPU:0"))
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index c6894e9013..04c712ce1d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,9 +300,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- features = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
- ).make_one_shot_iterator().get_next()
+ ds = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
+
+ features = iterator.get_next()
with dist.scope():
result = dist.call_for_each_tower(
@@ -1271,7 +1277,17 @@ class MirroredStrategyDefunTest(test.TestCase):
self.evaluate(device_result))
for defun in defuns:
- self.assertEqual(set(mock_model.variables), set(defun.variables))
+ # PolymorphicFunctions are specialized to the current device stack, so
+ # call_for_each has one trace per device. To check that the expected set
+ # of variables was accessed on each trace, we first retrieve each
+ # device-specific graph function.
+ per_device_graph_functions = dist.call_for_each_tower(
+ defun.get_concrete_function,
+ mock_model, *inputs, run_concurrently=False)
+ for device in devices:
+ graph_function = per_device_graph_functions.get(device=device)
+ self.assertEqual(set(mock_model.variables),
+ set(graph_function.graph.variables))
@test_util.run_in_graph_and_eager_modes()
def testVariableInDefun(self):
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 7644acedc9..17b7ab74f6 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,6 +51,7 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
+ session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 18b4503eff..9f92ba7dde 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -36,9 +36,29 @@ from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
+ASSIGNED_PORTS = set()
+lock = threading.Lock()
+
+
+def pick_unused_port():
+ """Returns an unused and unassigned local port."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ global ASSIGNED_PORTS
+ with lock:
+ while True:
+ port = portpicker.pick_unused_port()
+ if port > 10000 and port not in ASSIGNED_PORTS:
+ ASSIGNED_PORTS.add(port)
+ logging.info('Using local port %r', port)
+ return port
+
+
def _create_cluster(num_workers,
num_ps,
has_chief=False,
@@ -49,8 +69,8 @@ def _create_cluster(num_workers,
"""Creates and starts local servers and returns the cluster_spec dict."""
if _portpicker_import_error:
raise _portpicker_import_error # pylint: disable=raising-bad-type
- worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
- ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ worker_ports = [pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
if num_workers > 0:
@@ -58,9 +78,9 @@ def _create_cluster(num_workers,
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
if has_eval:
- cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
if has_chief:
- cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ cluster_dict['chief'] = ['localhost:%s' % pick_unused_port()]
cs = server_lib.ClusterSpec(cluster_dict)
@@ -139,11 +159,36 @@ def create_in_process_cluster(num_workers,
num_workers,
num_ps=num_ps,
has_chief=has_chief,
+ has_eval=has_eval,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
+def create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ """Create a cluster spec with tasks with unused local ports."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
+ if num_workers:
+ cluster_spec['worker'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec['ps'] = [
+ 'localhost:%s' % pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
+ return cluster_spec
+
+
class MultiWorkerTestBase(test.TestCase):
"""Base class for testing multi node strategy and dataset."""
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 6e9ba37a19..3064433129 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = distribution.distribute_dataset(
- dataset_fn).make_one_shot_iterator()
+ ds = distribution.distribute_dataset(dataset_fn)
+ if context.executing_eagerly():
+ iterator = ds.make_one_shot_iterator()
+ else:
+ iterator = ds.make_initializable_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
+ sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
deleted file mode 100644
index 1ff60c0762..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Extension of prefetching_ops to support more than one device."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest as data_nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset.
- one_shot: If true, we make a one shot iterator that's already initialized.
- devices: Devices on which to prefetch.
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server). Only used if one_shot
- is False.
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- devices,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
- self._devices = devices
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- target_device = gen_dataset_ops.iterator_get_device(
- self._input_iterator._iterator_resource)
- self._buffering_resources = []
- for device in nest.flatten(self._devices):
- with ops.device(device):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_prefetch_fn,
- output_types=data_nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)),
- target_device=target_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name)
- self._buffering_resources.append(buffer_resource_handle)
-
- if not self._one_shot:
- reset_ops = []
- for buffer_resource in self._buffering_resources:
- reset_ops.append(
- prefetching_ops.function_buffering_resource_reset(buffer_resource))
- with ops.control_dependencies(reset_ops):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_result = []
- # TODO(priyag): This will fail if the input size (typically number of
- # batches) is not divisible by number of devices.
- # How do we handle that more gracefully / let the user know?
- for buffer_resource in self._buffering_resources:
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
- buffer_resource,
- output_types=data_nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- data_nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
- flat_result.append(ret)
-
- return nest.pack_sequence_as(self._devices, flat_result)
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
- """A `Dataset` whose iterator prefetches elements to other device(s)."""
-
- def __init__(self, input_dataset, devices, buffer_size):
- self._input_dataset = input_dataset
- self._devices = devices
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- def make_one_shot_iterator(self):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=True,
- devices=self._devices,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- if context.executing_eagerly():
- raise RuntimeError(
- "make_initializable_iterator is not supported when eager "
- "execution is enabled.")
-
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- devices=self._devices,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
- # transformation methods is called.
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_devices()` must be the last "
- "transformation in a dataset pipeline.")
-
- # TODO(priyag): Fix the output types, shapes and classes to match the result
- # of get_next (which has the additional nesting layer of devices now).
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-def prefetch_to_devices(devices, buffer_size=None):
- """A transformation that prefetches dataset values to the given `devices`.
-
- NOTE: Although the transformation creates a `tf.data.Dataset`, the
- transformation must be the final `Dataset` in the input pipeline.
-
- Args:
- devices: A nested structure of devices on which to prefetch the data. It can
- be a single device name, or a tuple or list of device names.
- buffer_size: (Optional.) The number of elements to buffer on each device.
- Defaults to an automatically chosen value.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
-
- return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
deleted file mode 100644
index bb10b546a1..0000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops_v2."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class PrefetchingOpsV2Test(test.TestCase):
-
- def testPrefetchToOneDevice(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToTwoDevicesInAList(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- output = []
- with self.cached_session() as sess:
- for _ in range(5):
- result = sess.run(next_element)
- self.assertEqual(2, len(result))
- output.extend(result)
- self.assertEquals(set(range(10)), set(output))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToTwoDevicesWithReinit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(5):
- sess.run(next_element)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- sess.run(iterator.initializer)
- for _ in range(5):
- sess.run(next_element)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py
index 5aa19cf6a9..09b351ffa4 100644
--- a/tensorflow/contrib/distribute/python/single_loss_example.py
+++ b/tensorflow/contrib/distribute/python/single_loss_example.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import step_fn
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -59,10 +58,9 @@ def minimize_loss_example(optimizer_fn,
def dataset_fn():
dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
- # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be
+ # TODO(isaprykin): batch with drop_remainder causes shapes to be
# fully defined for TPU. Remove this when XLA supports dynamic shapes.
- return dataset.apply(
- batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True))
+ return dataset.batch(1, drop_remainder=True)
# An Optimizer instance is created either outside or inside model_fn.
outer_optimizer = None
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 1b5a4f64e5..23bf36184f 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- self._iterator = self._distributed_input.make_one_shot_iterator()
+ if context.executing_eagerly():
+ self._iterator = self._distributed_input.make_one_shot_iterator()
+ else:
+ # TODO(priyag): Expose initializer via some initializer property.
+ self._iterator = self._distributed_input.make_initializable_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index f1ada49fa3..1ff9b9ceec 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
+ sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 6ba83976fc..a6762e5e87 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -158,7 +158,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
raise ValueError(
'TPU currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ 'dataset.batch(..., drop_remainder=True).')
types = nest.flatten(iterator.output_types)
enqueue_ops = [
@@ -307,6 +307,22 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def num_towers_per_host(self):
return self._tpu_metadata.num_of_cores_per_host
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
@@ -324,4 +340,3 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
cluster_spec = self._tpu_cluster_resolver.cluster_spec()
if cluster_spec:
session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index fafa6384a1..a0cd029f51 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next(name=name)
+ data_list = self._iterator.get_next()
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,21 +703,26 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(self, dataset, devices, prefetch_on_device=None):
+ def __init__(
+ self,
+ dataset,
+ devices,
+ prefetch_on_device=None,
+ source_device="/cpu:0",
+ ):
self._devices = devices
+ self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
- # TODO(priyag): Enable prefetching in eager mode.
+ # TODO(rohanj): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- if self._prefetch_on_device:
- self._dataset = dataset.apply(
- prefetching_ops_v2.prefetch_to_devices(self._devices))
- else:
+ self._dataset = dataset
+ if not self._prefetch_on_device:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -725,15 +730,33 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
+ # Graph mode prefetching with one shot iterator is disabled.
+ if not context.executing_eagerly():
+ raise ValueError("Cannot create a one shot iterator. Please use "
+ "`make_initializable_iterator()` instead.")
+ # Eager mode prefetching would error out in constructor. Only remaining
+ # cases are non-prefetching eager / graph mode. We delegate to
+ # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator, self._devices, prefetch_on_device=False)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- dataset_iterator = self._dataset.make_initializable_iterator()
+ # Eager mode generates already initialized iterators. Hence we cannot create
+ # an initializable iterator.
+ if context.executing_eagerly():
+ raise ValueError("Cannot create initializable iterator in Eager mode. "
+ "Please use `make_one_shot_iterator` instead.")
+ if self._prefetch_on_device:
+ dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ self._dataset, self._devices, source_device=self._source_device)
+ else:
+ dataset_iterator = self._dataset.make_initializable_iterator()
return PerDeviceDataIterator(
- dataset_iterator, self._devices, self._prefetch_on_device)
+ dataset_iterator,
+ self._devices,
+ prefetch_on_device=self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -813,7 +836,10 @@ class MultiWorkerDataset(object):
worker_input = input_ops.auto_shard_dataset(
worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
+ worker_input,
+ worker_devices,
+ source_device=worker,
+ prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 15a85a28f5..002d61f46e 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- iterator = per_device_dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ iterator = per_device_dataset.make_one_shot_iterator()
+ else:
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -366,20 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_one_shot_iterator()
+ iterator = per_device_dataset.make_initializable_iterator()
+ self.evaluate([iterator.initializer])
- # With prefetching, we cannot guarantee which input ends up on which
- # device, so we verify that the complete set seen on all devices is
- # correct, and equal numbers are distributed to each device.
- combined_actual = []
- combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- combined_actual.extend(self.evaluate([
- values.select_device(d, next_element) for d in devices]))
- combined_expected.extend(expected_value)
-
- self.assertEqual(set(combined_expected), set(combined_actual))
+ computed_value = self.evaluate(
+ [values.select_device(d, next_element) for d in devices])
+ self.assertEqual(expected_value, computed_value)
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 97c53ae2b9..3ff7da4f89 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -25,7 +25,6 @@ py_library(
"`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
@@ -61,7 +60,6 @@ py_library(
":bijectors_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -166,6 +164,7 @@ cuda_py_test(
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
+ tags = ["notap"],
)
cuda_py_test(
@@ -705,8 +704,8 @@ cuda_py_test(
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -721,8 +720,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/ops/linalg",
],
shard_count = 4,
tags = ["noasan"], # times out, http://b/78588814
@@ -738,8 +737,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -793,8 +792,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -830,8 +829,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -851,8 +850,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -870,8 +869,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -906,8 +905,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -925,10 +924,10 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
@@ -944,8 +943,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -963,8 +962,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -982,8 +981,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1001,8 +1000,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1020,8 +1019,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1039,8 +1038,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1074,8 +1073,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1125,8 +1124,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1160,8 +1159,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1179,8 +1178,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1200,8 +1199,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1220,8 +1219,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1239,8 +1238,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1258,8 +1257,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1277,8 +1276,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1296,8 +1295,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1315,8 +1314,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
index a7bd51430e..1e36b7ff9b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index 8dad80aa64..c32ea9ade7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -93,12 +93,12 @@ class SoftsignBijectorTest(test.TestCase):
bijector.inverse_log_det_jacobian(y, event_ndims=1)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softsign(validate_args=True)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.linspace(-0.99, 0.99, 100).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
index f073f51a69..9b9b3ce2dd 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py
@@ -212,7 +212,7 @@ class DistributionTest(test.TestCase):
def testStrWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("tf.distributions.Normal("
+ ("tfp.distributions.Normal("
"\"Normal/\", "
"batch_shape=(), "
"event_shape=(), "
@@ -221,7 +221,7 @@ class DistributionTest(test.TestCase):
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("tf.distributions.Chi2("
+ ("tfp.distributions.Chi2("
"\"silly/\", " # What a silly name that is!
"batch_shape=(2,), "
"event_shape=(), "
@@ -230,7 +230,7 @@ class DistributionTest(test.TestCase):
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("tf.distributions.Exponential(\"Exponential/\", "
+ ("tfp.distributions.Exponential(\"Exponential/\", "
# No batch shape.
"event_shape=(), "
"dtype=float32)"),
@@ -240,7 +240,7 @@ class DistributionTest(test.TestCase):
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN/\", "
"batch_shape=(2,), "
"event_shape=(2,), "
@@ -251,7 +251,7 @@ class DistributionTest(test.TestCase):
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("tf.distributions.MultivariateNormalDiag("
+ ("tfp.distributions.MultivariateNormalDiag("
"\"MVN2/\", "
"batch_shape=(?,), " # Partially known.
"event_shape=(3,), "
@@ -261,7 +261,7 @@ class DistributionTest(test.TestCase):
def testReprWorksCorrectlyScalar(self):
normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1))
self.assertEqual(
- ("<tf.distributions.Normal"
+ ("<tfp.distributions.Normal"
" 'Normal/'"
" batch_shape=()"
" event_shape=()"
@@ -270,7 +270,7 @@ class DistributionTest(test.TestCase):
chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly")
self.assertEqual(
- ("<tf.distributions.Chi2"
+ ("<tfp.distributions.Chi2"
" 'silly/'" # What a silly name that is!
" batch_shape=(2,)"
" event_shape=()"
@@ -279,7 +279,7 @@ class DistributionTest(test.TestCase):
exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32))
self.assertEqual(
- ("<tf.distributions.Exponential"
+ ("<tfp.distributions.Exponential"
" 'Exponential/'"
" batch_shape=<unknown>"
" event_shape=()"
@@ -290,7 +290,7 @@ class DistributionTest(test.TestCase):
mvn_static = tfd.MultivariateNormalDiag(
loc=np.zeros([2, 2]), name="MVN")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN/'"
" batch_shape=(2,)"
" event_shape=(2,)"
@@ -301,7 +301,7 @@ class DistributionTest(test.TestCase):
loc=array_ops.placeholder(shape=[None, 3], dtype=dtypes.float32),
name="MVN2")
self.assertEqual(
- ("<tf.distributions.MultivariateNormalDiag"
+ ("<tfp.distributions.MultivariateNormalDiag"
" 'MVN2/'"
" batch_shape=(?,)" # Partially known.
" event_shape=(3,)"
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
index 196cc41335..13370497ce 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py
@@ -22,7 +22,6 @@ import numpy as np
from scipy import stats
from tensorflow.contrib import distributions
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -30,6 +29,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import test
bs = bijectors
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
deleted file mode 100644
index 42ecea034d..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD
+++ /dev/null
@@ -1,51 +0,0 @@
-# Description:
-# Internal testing utilities, e.g., computing the correct answer to
-# put in a unit test.
-
-licenses(["notice"]) # Apache 2.0
-
-py_library(
- name = "correlation_matrix_volumes_py",
- srcs = [
- "correlation_matrix_volumes_lib.py",
- ],
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:math_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_binary(
- name = "correlation_matrix_volumes",
- srcs = [
- "correlation_matrix_volumes.py",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- ],
-)
-
-py_test(
- name = "correlation_matrix_volumes_test",
- size = "medium",
- srcs = ["correlation_matrix_volumes_test.py"],
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- ":correlation_matrix_volumes_py",
- # For statistical testing
- "//tensorflow/contrib/distributions:distributions_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- ],
-)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
deleted file mode 100644
index 2eab51cd30..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Executable to estimate the volume of various sets of correlation matrices.
-
-See correlation_matrix_volumes_lib.py for purpose and methodology.
-
-Invocation example:
-```
-python correlation_matrix_volumes.py --num_samples 1e7
-```
-
-This will compute 10,000,000-sample confidence intervals for the
-volumes of several sets of correlation matrices. Which sets, and the
-desired statistical significance, are hard-coded in this source file.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import pprint
-
-from absl import app
-from absl import flags
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-
-FLAGS = flags.FLAGS
-
-# Float to support giving the number of samples in scientific notation.
-# The production run used for the LKJ test used 1e7 samples.
-flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.')
-
-
-def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- # This wrapper undoes the batching in compute_true_volumes, because
- # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM.
- bounds = {}
- for db in det_bounds:
- bounds[db] = corr.compute_true_volumes(
- [db], dim, num_samples, error_rate=error_rate, seed=seed)[db]
- return bounds
-
-
-# The particular bounds in all three of these functions were chosen by
-# a somewhat arbitrary walk through an empirical tradeoff, for the
-# purpose of testing the LKJ distribution. Setting the determinant
-# bound lower
-# - Covers more of the testee's sample space, and
-# - Increases the probability that the rejection sampler will hit, thus
-# - Decreases the relative error (at a fixed sample count) in the
-# rejection-based volume estimate;
-# but also
-# - Increases the variance of the estimator used in the LKJ test.
-# This latter variance is also affected by the dimension and the
-# tested concentration parameter, and can be compensated for with more
-# compute (expensive) or a looser discrepancy limit (unsatisfying).
-# The values here are the projection of the points in that test design
-# space that ended up getting chosen.
-def compute_3x3_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 3, num_samples, error_rate=5e-7, seed=46)
-
-
-def compute_4x4_volumes(num_samples):
- det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45]
- return ctv_debatched(
- det_bounds, 4, num_samples, error_rate=5e-7, seed=47)
-
-
-def compute_5x5_volumes(num_samples):
- det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4]
- return ctv_debatched(
- det_bounds, 5, num_samples, error_rate=5e-7, seed=48)
-
-
-def main(_):
- full_bounds = {}
- full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples))
- full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples))
- full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples))
- pprint.pprint(full_bounds)
-
-if __name__ == '__main__':
- app.run(main)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
deleted file mode 100644
index 455e71f00c..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py
+++ /dev/null
@@ -1,323 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Estimating the volume of the correlation matrices with bounded determinant.
-
-Why? Because lkj_test.py tests the sampler for the LKJ distribution
-by estimating the same volume another way.
-
-How? Rejection sampling. Or, more precisely, importance sampling,
-proposing from the uniform distribution on symmetric matrices with
-diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation
-matrix if and only if it is also positive semi-definite.
-
-The samples can then be converted into a confidence interval on the
-volume in question by the [Clopper-Pearson
-method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval),
-also implemented here.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import importlib
-import sys
-
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import uniform
-from tensorflow.python.ops.distributions import util
-from tensorflow.python.platform import tf_logging
-
-__all__ = [
- "correlation_matrix_volume_rejection_samples",
- "compute_true_volumes",
-]
-
-
-def try_import(name): # pylint: disable=invalid-name
- module = None
- try:
- module = importlib.import_module(name)
- except ImportError as e:
- tf_logging.warning("Could not import %s: %s" % (name, str(e)))
- return module
-
-optimize = try_import("scipy.optimize")
-stats = try_import("scipy.stats")
-
-
-def _psd_mask(x):
- """Computes whether each square matrix in the input is positive semi-definite.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
-
- Returns:
- mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each
- scalar is 1 if the corresponding matrix was PSD, otherwise 0.
- """
- # Allegedly
- # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite
- # it is more efficient to test for positive semi-definiteness by
- # trying to compute the Cholesky decomposition -- the matrix is PSD
- # if you succeed and not PSD if you fail. However, TensorFlow's
- # Cholesky raises an exception if _any_ of the input matrices are
- # not PSD, from which I don't know how to extract _which ones_, so I
- # proceed by explicitly computing all the eigenvalues and checking
- # whether they are all positive or not.
- #
- # Also, as was discussed in the answer, it is somewhat dangerous to
- # treat SPD-ness as binary in floating-point arithmetic. Cholesky
- # factorization can complete and 'look' like everything is fine
- # (e.g., O(1) entries and a diagonal of all ones) but the matrix can
- # have an exponential condition number.
- eigenvalues, _ = linalg_ops.self_adjoint_eig(x)
- return math_ops.cast(
- math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype)
-
-
-def _det_large_enough_mask(x, det_bounds):
- """Returns whether the input matches the given determinant limit.
-
- Args:
- x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`.
- det_bounds: A floating-point `Tensor` that must broadcast to shape
- `[B1, ..., Bn]`, giving the desired lower bound on the
- determinants in `x`.
-
- Returns:
- mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each
- scalar is 1 if the corresponding matrix had determinant above
- the corresponding bound, otherwise 0.
- """
- # For the curious: I wonder whether it is possible and desirable to
- # use a Cholesky decomposition-based algorithm for this, since the
- # only matrices whose determinant this code cares about will be PSD.
- # Didn't figure out how to code that in TensorFlow.
- #
- # Expert opinion is that it would be about twice as fast since
- # Cholesky is roughly half the cost of Gaussian Elimination with
- # Partial Pivoting. But this is less of an impact than the switch in
- # _psd_mask.
- return math_ops.cast(
- linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype)
-
-
-def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed):
- """Returns a uniformly random `Tensor` of "correlation-like" matrices.
-
- A "correlation-like" matrix is a symmetric square matrix with all entries
- between -1 and 1 (inclusive) and 1s on the main diagonal. Of these,
- the ones that are positive semi-definite are exactly the correlation
- matrices.
-
- Args:
- num_rows: Python `int` dimension of the correlation-like matrices.
- batch_shape: `Tensor` or Python `tuple` of `int` shape of the
- batch to return.
- dtype: `dtype` of the `Tensor` to return.
- seed: Random seed.
-
- Returns:
- matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]`
- and dtype `dtype`. Each entry is in [-1, 1], and each matrix
- along the bottom two dimensions is symmetric and has 1s on the
- main diagonal.
- """
- num_entries = num_rows * (num_rows + 1) / 2
- ones = array_ops.ones(shape=[num_entries], dtype=dtype)
- # It seems wasteful to generate random values for the diagonal since
- # I am going to throw them away, but `fill_triangular` fills the
- # diagonal, so I probably need them.
- # It's not impossible that it would be more efficient to just fill
- # the whole matrix with random values instead of messing with
- # `fill_triangular`. Then would need to filter almost half out with
- # `matrix_band_part`.
- unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed)
- tril = util.fill_triangular(unifs)
- symmetric = tril + array_ops.matrix_transpose(tril)
- diagonal_ones = array_ops.ones(
- shape=util.pad(batch_shape, axis=0, back=True, value=num_rows),
- dtype=dtype)
- return array_ops.matrix_set_diag(symmetric, diagonal_ones)
-
-
-def correlation_matrix_volume_rejection_samples(
- det_bounds, dim, sample_shape, dtype, seed):
- """Returns rejection samples from trying to get good correlation matrices.
-
- The proposal being rejected from is the uniform distribution on
- "correlation-like" matrices. We say a matrix is "correlation-like"
- if it is a symmetric square matrix with all entries between -1 and 1
- (inclusive) and 1s on the main diagonal. Of these, the ones that
- are positive semi-definite are exactly the correlation matrices.
-
- The rejection algorithm, then, is to sample a `Tensor` of
- `sample_shape` correlation-like matrices of dimensions `dim` by
- `dim`, and check each one for (i) being a correlation matrix (i.e.,
- PSD), and (ii) having determinant at least the corresponding entry
- of `det_bounds`.
-
- Args:
- det_bounds: A `Tensor` of lower bounds on the determinants of
- acceptable matrices. The shape must broadcast with `sample_shape`.
- dim: A Python `int` dimension of correlation matrices to sample.
- sample_shape: Python `tuple` of `int` shape of the samples to
- compute, excluding the two matrix dimensions.
- dtype: The `dtype` in which to do the computation.
- seed: Random seed.
-
- Returns:
- weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the
- corresponding matrix was not a correlation matrix, or had too
- small of a determinant. Otherwise, the entry is the
- multiplicative inverse of the density of proposing that matrix
- uniformly, i.e., the volume of the set of `dim` by `dim`
- correlation-like matrices.
- volume: The volume of the set of `dim` by `dim` correlation-like
- matrices.
- """
- with ops.name_scope("rejection_sampler"):
- rej_proposals = _uniform_correlation_like_matrix(
- dim, sample_shape, dtype, seed=seed)
- rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.)
- # The density of proposing any given point is 1 / rej_proposal_volume;
- # The weight of that point should be scaled by
- # 1 / density = rej_proposal_volume.
- rej_weights = rej_proposal_volume * _psd_mask(
- rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds)
- return rej_weights, rej_proposal_volume
-
-
-def _clopper_pearson_confidence_interval(samples, error_rate):
- """Computes a confidence interval for the mean of the given 1-D distribution.
-
- Assumes (and checks) that the given distribution is Bernoulli, i.e.,
- takes only two values. This licenses using the CDF of the binomial
- distribution for the confidence, which is tighter (for extreme
- probabilities) than the DKWM inequality. The method is known as the
- [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Assumes:
-
- - The given samples were drawn iid from the distribution of interest.
-
- - The given distribution is a Bernoulli, i.e., supported only on
- low and high.
-
- Guarantees:
-
- - The probability (over the randomness of drawing the given sample)
- that the true mean is outside the returned interval is no more
- than the given error_rate.
-
- Args:
- samples: `np.ndarray` of samples drawn iid from the distribution
- of interest.
- error_rate: Python `float` admissible rate of mistakes.
-
- Returns:
- low: Lower bound of confidence interval.
- high: Upper bound of confidence interval.
-
- Raises:
- ValueError: If `samples` has rank other than 1 (batch semantics
- are not implemented), or if `samples` contains values other than
- `low` or `high` (as that makes the distribution not Bernoulli).
- """
- # TODO(b/78025336) Migrate this confidence interval function
- # to statistical_testing.py. In order to do that
- # - Get the binomial CDF from the Binomial distribution
- # - Implement scalar root finding in TF. Batch bisection search
- # shouldn't be too hard, and is definitely good enough for this
- # problem. Batching the Brent algorithm (from scipy) that is used
- # here may be more involved, but may also not be necessary---it's
- # only used here because scipy made it convenient. In particular,
- # robustness is more important than speed here, which may make
- # bisection search actively better.
- # - The rest is just a matter of rewriting in the appropriate style.
- if optimize is None or stats is None:
- raise ValueError(
- "Scipy is required for computing Clopper-Pearson confidence intervals")
- if len(samples.shape) != 1:
- raise ValueError("Batch semantics not implemented")
- n = len(samples)
- low = np.amin(samples)
- high = np.amax(samples)
- successes = np.count_nonzero(samples - low)
- failures = np.count_nonzero(samples - high)
- if successes + failures != n:
- uniques = np.unique(samples)
- msg = ("Purportedly Bernoulli distribution had distinct samples"
- " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2]))
- raise ValueError(msg)
- def p_small_enough(p):
- prob = stats.binom.logcdf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- def p_big_enough(p):
- prob = stats.binom.logsf(successes, n, p)
- return prob - np.log(error_rate / 2.)
- high_p = optimize.brentq(
- p_small_enough, float(successes) / n, 1., rtol=1e-9)
- low_p = optimize.brentq(
- p_big_enough, 0., float(successes) / n, rtol=1e-9)
- low_interval = low + (high - low) * low_p
- high_interval = low + (high - low) * high_p
- return (low_interval, high_interval)
-
-
-def compute_true_volumes(
- det_bounds, dim, num_samples, error_rate=1e-6, seed=42):
- """Returns confidence intervals for the desired correlation matrix volumes.
-
- The confidence intervals are computed by the [Clopper-Pearson method]
- (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval).
-
- Args:
- det_bounds: A rank-1 numpy array of lower bounds on the
- determinants of acceptable matrices. Entries must be unique.
- dim: A Python `int` dimension of correlation matrices to sample.
- num_samples: The number of samples to draw.
- error_rate: The statistical significance of the returned
- confidence intervals. The significance is broadcast: Each
- returned interval separately may be incorrect with probability
- (under the sample of correlation-like matrices drawn internally)
- at most `error_rate`.
- seed: Random seed.
-
- Returns:
- bounds: A Python `dict` mapping each determinant bound to the low, high
- tuple giving the confidence interval.
- """
- bounds = {}
- with session.Session() as sess:
- rej_weights, _ = correlation_matrix_volume_rejection_samples(
- det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed)
- rej_weights = sess.run(rej_weights)
- for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds):
- template = ("Estimating volume of {}x{} correlation "
- "matrices with determinant >= {}.")
- print(template.format(dim, dim, det))
- sys.stdout.flush()
- bounds[det] = _clopper_pearson_confidence_interval(
- rw, error_rate=error_rate)
- return bounds
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
deleted file mode 100644
index 8f99300e63..0000000000
--- a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for correlation_matrix_volumes_lib.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr
-from tensorflow.contrib.distributions.python.ops import statistical_testing as st
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.platform import test
-
-
-# NxN correlation matrices are determined by the N*(N-1)/2
-# lower-triangular entries. In addition to being between -1 and 1,
-# they must also obey the constraint that the determinant of the
-# resulting symmetric matrix is non-negative. In 2x2, we can even
-# analytically compute the volume when the determinant is bounded to >
-# epsilon, as that boils down to the one lower-triangular entry being
-# less than 1 - epsilon in absolute value.
-def two_by_two_volume(det_bound):
- return 2 * np.sqrt(1.0 - det_bound)
-
-
-# The post
-# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/
-# derives (with elementary calculus) that the volume (with respect to
-# Lebesgue^3 measure) of the set of 3x3 correlation matrices is
-# pi^2/2. The same result is also obtained by [1].
-def three_by_three_volume():
- return np.pi**2 / 2.
-
-
-# The volume of the unconstrained set of correlation matrices is also
-# the normalization constant of the LKJ distribution from [2]. As
-# part of defining the distribution, that reference a derives general
-# formula for this volume for all dimensions. A TensorFlow
-# computation thereof gave the below result for 4x4:
-def four_by_four_volume():
- # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0]))
- return 11.6973076
-
-# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of
-# correlation matrices." The American Statistician, 48(4), 276-279.
-
-# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating
-# random correlation matrices based on vines and extended onion
-# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001.
-
-
-class CorrelationMatrixVolumesTest(test.TestCase):
-
- def testRejection2D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- exact_volumes = two_by_two_volume(det_bounds)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43)
- # shape of rej_weights: [num_samples, 9, 2, 2]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- # Correct the false fail rate due to different broadcasting
- false_fail_rate=1.1e-7, false_pass_rate=1e-6),
- 0.036)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection3D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = np.array([three_by_three_volume()], dtype=np.float32)
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44)
- # shape of rej_weights: [num_samples, 1, 3, 3]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 3% relative error
- 0.15)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testRejection4D(self):
- num_samples = int(1e5) # Chosen for a small min detectable discrepancy
- det_bounds = np.array([0.0], dtype=np.float32)
- exact_volumes = [four_by_four_volume()]
- (rej_weights,
- rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples(
- det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45)
- # shape of rej_weights: [num_samples, 1, 4, 4]
- chk1 = st.assert_true_mean_equal_by_dkwm(
- rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes,
- false_fail_rate=1e-6)
- chk2 = check_ops.assert_less(
- st.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=rej_proposal_volume,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- # Going for about a 10% relative error
- 1.1)
- with ops.control_dependencies([chk1, chk2]):
- rej_weights = array_ops.identity(rej_weights)
- self.evaluate(rej_weights)
-
- def testVolumeEstimation2D(self):
- # Test that the confidence intervals produced by
- # corr.compte_true_volumes are sound, in the sense of containing
- # the exact volume.
- num_samples = int(1e5) # Chosen by symmetry with testRejection2D
- det_bounds = np.array(
- [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32)
- volume_bounds = corr.compute_true_volumes(
- det_bounds, 2, num_samples, error_rate=1e-6, seed=47)
- exact_volumes = two_by_two_volume(det_bounds)
- for det, volume in zip(det_bounds, exact_volumes):
- computed_low, computed_high = volume_bounds[det]
- self.assertLess(computed_low, volume)
- self.assertGreater(computed_high, volume)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py
index bb9b8043b2..3ba1c3a665 100644
--- a/tensorflow/contrib/distributions/python/ops/autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py
@@ -65,13 +65,14 @@ class Autoregressive(distribution_lib.Distribution):
```
where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn`
- constructs a `tf.distributions.Distribution`-like instance, and `x0` is a
+ constructs a `tfp.distributions.Distribution`-like instance, and `x0` is a
fixed initializing `Tensor`.
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
def normal_fn(self, event_size):
n = event_size * (event_size + 1) / 2
@@ -127,7 +128,7 @@ class Autoregressive(distribution_lib.Distribution):
Args:
distribution_fn: Python `callable` which constructs a
- `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
+ `tfp.distributions.Distribution`-like instance from a `Tensor` (e.g.,
`sample0`). The function must respect the "autoregressive property",
i.e., there exists a permutation of event such that each coordinate is a
diffeomorphic function of on preceding coordinates.
diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
index 519077bc9a..612376efb7 100644
--- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py
@@ -45,7 +45,8 @@ class BatchReshape(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
dtype = np.float32
dims = 2
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
index 25f29452c3..ba31697c58 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape
from tensorflow.python.framework import dtypes
@@ -29,6 +28,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
index 296e66f2b2..3b3d8ee6f2 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py
@@ -61,8 +61,8 @@ class MaskedAutoregressiveFlow(bijector.Bijector):
`shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves
this property by zeroing out weights in its `masked_dense` layers.
- In the `tf.distributions` framework, a "normalizing flow" is implemented as a
- `tf.contrib.distributions.bijectors.Bijector`. The `forward` "autoregression"
+ In the `tfp` framework, a "normalizing flow" is implemented as a
+ `tfp.bijectors.Bijector`. The `forward` "autoregression"
is implemented using a `tf.while_loop` and a deep neural network (DNN) with
masked weights such that the autoregressive property is automatically met in
the `inverse`.
@@ -126,8 +126,9 @@ class MaskedAutoregressiveFlow(bijector.Bijector):
#### Examples
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
dims = 5
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
index f182a1adcb..178c3c94bf 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py
@@ -41,9 +41,10 @@ class Permute(bijector.Bijector):
"""Permutes the rightmost dimension of a `Tensor`.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
+ reverse = tfb.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])
# ==> [1., 0., -1]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
index 773ae24461..0bcb08cdea 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py
@@ -90,8 +90,9 @@ class RealNVP(bijector.Bijector):
#### Example Use
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
index c8282229a3..71ac29038f 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py
@@ -80,9 +80,10 @@ class Reshape(bijector.Bijector):
Example usage:
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfb = tfp.bijectors
- r = tfd.bijectors.Reshape(event_shape_out=[1, -1])
+ r = tfb.Reshape(event_shape_out=[1, -1])
r.forward([3., 4.]) # shape [2]
# ==> [[3., 4.]] # shape [1, 2]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
index 6fbe866578..0a6d690b65 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py
@@ -42,7 +42,10 @@ class ScaleTriL(chain.Chain):
#### Examples
```python
- tfb = tf.contrib.distributions.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
+
b = tfb.ScaleTriL(
diag_bijector=tfb.Exp(),
diag_shift=None)
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index cb5223b055..c461833b9a 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -63,7 +63,8 @@ class Cauchy(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Cauchy distribution.
dist = tfd.Cauchy(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index affc64a14f..507c5d3679 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -198,8 +198,11 @@ class Deterministic(_BaseDeterministic):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single Deterministic supported at zero.
- constant = tf.contrib.distributions.Deterministic(0.)
+ constant = tfd.Deterministic(0.)
constant.prob(0.)
==> 1.
constant.prob(2.)
@@ -208,7 +211,7 @@ class Deterministic(_BaseDeterministic):
# Initialize a [2, 2] batch of scalar constants.
loc = [[0., 1.], [2., 3.]]
x = [[0., 1.1], [1.99, 3.]]
- constant = tf.contrib.distributions.Deterministic(loc)
+ constant = tfd.Deterministic(loc)
constant.prob(x)
==> [[1., 0.], [0., 1.]]
```
@@ -310,7 +313,8 @@ class VectorDeterministic(_BaseDeterministic):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
constant = tfd.Deterministic([0., 2.])
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 6959b3e877..b4ad33cf6d 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
@@ -27,6 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.ops.distributions import distribution as distribution_lib
# The following two lines are redundant, in a sense. The first enables
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index acdea4d61d..4b50df5b48 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -63,7 +63,8 @@ class _Gumbel(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Gumbel distribution.
dist = tfd.Gumbel(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py
index b02c403106..f121637086 100644
--- a/tensorflow/contrib/distributions/python/ops/half_normal.py
+++ b/tensorflow/contrib/distributions/python/ops/half_normal.py
@@ -66,15 +66,18 @@ class HalfNormal(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar HalfNormal distribution.
- dist = tf.contrib.distributions.HalfNormal(scale=3.0)
+ dist = tfd.HalfNormal(scale=3.0)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued HalfNormals.
# The first has scale 11.0, the second 22.0
- dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0])
+ dist = tfd.HalfNormal(scale=[11.0, 22.0])
# Evaluate the pdf of the first distribution on 1.0, and the second on 1.5,
# returning a length two tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index 0672702b96..e1cfff3c66 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -70,7 +70,8 @@ class Independent(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Make independent distribution from a 2-batch Normal.
ind = tfd.Independent(
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 70d050d7a6..452628257e 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -89,7 +89,9 @@ class InverseGamma(distribution.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
dist = tfd.InverseGamma(concentration=3.0, rate=2.0)
dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 02e3bad51e..21c9b5a354 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -61,7 +61,8 @@ class Logistic(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Define a single scalar Logistic distribution.
dist = tfd.Logistic(loc=0., scale=3.)
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index 3b7114ef06..52b67f2c54 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -50,7 +50,9 @@ class Mixture(distribution.Distribution):
```python
# Create a mixture of two Gaussians:
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
mix = 0.3
bimix_gauss = tfd.Mixture(
cat=tfd.Categorical(probs=[mix, 1.-mix]),
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 8ffee940d0..f4d394ff29 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -44,7 +44,8 @@ class MixtureSameFamily(distribution.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
### Create a mixture of two scalar Gaussians:
@@ -113,12 +114,12 @@ class MixtureSameFamily(distribution.Distribution):
"""Construct a `MixtureSameFamily` distribution.
Args:
- mixture_distribution: `tf.distributions.Categorical`-like instance.
+ mixture_distribution: `tfp.distributions.Categorical`-like instance.
Manages the probability of selecting components. The number of
categories must match the rightmost batch dimension of the
`components_distribution`. Must have either scalar `batch_shape` or
`batch_shape` matching `components_distribution.batch_shape[:-1]`.
- components_distribution: `tf.distributions.Distribution`-like instance.
+ components_distribution: `tfp.distributions.Distribution`-like instance.
Right-most batch dimension indexes components.
validate_args: Python `bool`, default `False`. When `True` distribution
parameters are checked for validity despite possibly degrading runtime
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index cd0c282ba6..0b5b76be92 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -85,7 +85,8 @@ class MultivariateNormalDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate Gaussian.
mvn = tfd.MultivariateNormalDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index d8401801f2..80546083d3 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
@@ -87,7 +87,8 @@ class MultivariateNormalDiagPlusLowRank(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
# `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index dbc4c1b3dc..bcb4937980 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -73,7 +73,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index efe5a6d0d9..8fdc99824b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -91,7 +91,8 @@ class MultivariateNormalLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index d9110947ec..c21f70fc3b 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
@@ -77,13 +77,14 @@ class MultivariateNormalTriL(
```
Trainable (batch) lower-triangular matrices can be created with
- `tf.contrib.distributions.matrix_diag_transform()` and/or
- `tf.contrib.distributions.fill_triangular()`
+ `tfp.distributions.matrix_diag_transform()` and/or
+ `tfp.distributions.fill_triangular()`
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index 7a7ad1be35..85683e3233 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -220,7 +220,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of PoissonLogNormalQuadratureCompounds, one with
# prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.`
diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
index 18a0f754e6..134658deab 100644
--- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py
@@ -196,8 +196,9 @@ class QuantizedDistribution(distributions.Distribution):
parameter determining the unnormalized probability of that component.
```python
- tfd = tf.contrib.distributions
- tfb = tfd.bijectors
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+ tfb = tfp.bijectors
net = wavenet(inputs)
loc, unconstrained_scale, logits = tf.split(net,
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index a9d0fb4ccf..4b520b912e 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -124,7 +124,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight)
distribution: `tf.Distribution`-like instance. Distribution that is
transformed to produce this distribution.
- Default is `tf.distributions.Normal(0., 1.)`.
+ Default is `tfp.distributions.Normal(0., 1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index c25e8c51d7..af22f4843a 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -30,27 +30,27 @@ is some expected constant. Suppose the support of P is the interval
`[0, 1]`. Then you might do this:
```python
-tfd = tf.contrib.distributions
-
-expected_mean = ...
-num_samples = 5000
-samples = ... draw 5000 samples from P
-
-# Check that the mean looks right
-check1 = tfd.assert_true_mean_equal_by_dkwm(
- samples, low=0., high=1., expected=expected_mean,
- false_fail_rate=1e-6)
-
-# Check that the difference in means detectable with 5000 samples is
-# small enough
-check2 = tf.assert_less(
- tfd.min_discrepancy_of_true_means_detectable_by_dkwm(
- num_samples, low=0., high=1.0,
- false_fail_rate=1e-6, false_pass_rate=1e-6),
- 0.01)
-
-# Be sure to execute both assertion ops
-sess.run([check1, check2])
+ from tensorflow_probability.python.distributions.internal import statistical_testing
+
+ expected_mean = ...
+ num_samples = 5000
+ samples = ... draw 5000 samples from P
+
+ # Check that the mean looks right
+ check1 = statistical_testing.assert_true_mean_equal_by_dkwm(
+ samples, low=0., high=1., expected=expected_mean,
+ false_fail_rate=1e-6)
+
+ # Check that the difference in means detectable with 5000 samples is
+ # small enough
+ check2 = tf.assert_less(
+ statistical_testing.min_discrepancy_of_true_means_detectable_by_dkwm(
+ num_samples, low=0., high=1.0,
+ false_fail_rate=1e-6, false_pass_rate=1e-6),
+ 0.01)
+
+ # Be sure to execute both assertion ops
+ sess.run([check1, check2])
```
The second assertion is an instance of experiment design. It's a
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index ece03fe4aa..a3d178357b 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -23,7 +23,6 @@ import numpy as np
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -36,6 +35,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import categorical as categorical_lib
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import normal as normal_lib
+from tensorflow.python.ops.linalg import linear_operator_addition as linop_add_lib
from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib
from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib
from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib
@@ -300,7 +300,8 @@ class VectorDiffeomixture(distribution_lib.Distribution):
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.],
# another with mix_loc=[1]. In both cases, `K=2` and the affine
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 73356a3625..36cbd71f8b 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -90,7 +90,8 @@ class VectorExponentialDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index 9a47b48557..fd5bf9ecc7 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -108,7 +108,8 @@ class VectorExponentialLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index e68ddc569c..8cd4e128c7 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -102,7 +102,8 @@ class VectorLaplaceDiag(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 2-variate VectorLaplace.
vla = tfd.VectorLaplaceDiag(
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index 3923161a33..67d2ccd28d 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -110,7 +110,8 @@ class VectorLaplaceLinearOperator(
#### Examples
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate VectorLaplace with some desired covariance.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 49ffff24ca..da57d0cb55 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -152,7 +152,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
broadcastable with `event_shape`.
distribution: `tf.Distribution`-like instance. Distribution from which `k`
iid samples are used as input to transformation `F`. Default is
- `tf.distributions.Normal(loc=0., scale=1.)`.
+ `tfp.distributions.Normal(loc=0., scale=1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index f289b39e51..bad91a0844 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -92,7 +92,8 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
Extra leading dimensions, if provided, allow for batches.
```python
- tfd = tf.contrib.distributions
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
# Initialize a single 3-variate vector Student's t-distribution.
mu = [1., 2, 3]
diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py
index f1accaaa4c..ee2fc58864 100644
--- a/tensorflow/contrib/distributions/python/ops/wishart.py
+++ b/tensorflow/contrib/distributions/python/ops/wishart.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import math
import numpy as np
-from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.python.framework import constant_op
@@ -36,6 +35,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.linalg import linalg
from tensorflow.python.util import deprecation
__all__ = [
@@ -480,11 +480,14 @@ class WishartCholesky(_WishartLinearOperator):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Cholesky factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
chol_scale = tf.cholesky(...) # Shape is [3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -498,14 +501,14 @@ class WishartCholesky(_WishartLinearOperator):
# Initialize two 3x3 Wisharts with Cholesky factored scale matrices.
df = [5, 4]
chol_scale = tf.cholesky(...) # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartCholesky(df=df, scale=chol_scale)
+ dist = tfd.WishartCholesky(df=df, scale=chol_scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3].
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfp.distributions.matrix_diag_transform.
```
"""
@@ -604,11 +607,14 @@ class WishartFull(_WishartLinearOperator):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Initialize a single 3x3 Wishart with Full factored scale matrix and 5
# degrees-of-freedom.(*)
df = 5
scale = ... # Shape is [3, 3]; positive definite.
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on an observation in R^3, returning a scalar.
x = ... # A 3x3 positive definite matrix.
@@ -622,14 +628,14 @@ class WishartFull(_WishartLinearOperator):
# Initialize two 3x3 Wisharts with Full factored scale matrices.
df = [5, 4]
scale = ... # Shape is [2, 3, 3].
- dist = tf.contrib.distributions.WishartFull(df=df, scale=scale)
+ dist = tfd.WishartFull(df=df, scale=scale)
# Evaluate this on four observations.
x = [[x0, x1], [x2, x3]] # Shape is [2, 2, 3, 3]; xi is positive definite.
dist.prob(x) # Shape is [2, 2].
# (*) - To efficiently create a trainable covariance matrix, see the example
- # in tf.contrib.distributions.matrix_diag_transform.
+ # in tfd.matrix_diag_transform.
```
"""
diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index 86d203452e..4bd2769e87 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -44,7 +44,6 @@ Installation instructions at https://www.tensorflow.org/install/
For an introduction to eager execution in TensorFlow, see:
-- [User Guide](https://www.tensorflow.org/guide/eager) ([source](../../docs_src/guide/eager.md))
-- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb)
-- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb)
-- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb)
+- [User Guide](https://www.tensorflow.org/guide/eager) ([source](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/index.md))
+- Notebook: [Basic Usage](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb)
+- Notebook: [Automatic differentiation and gradient tape](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb)
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 84517b57c7..33a1d572a2 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":parameter_server",
":remote",
":saver",
"//tensorflow/python:framework_ops",
@@ -97,6 +98,18 @@ py_library(
],
)
+py_library(
+ name = "parameter_server",
+ srcs = ["parameter_server.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
cuda_py_test(
name = "saver_test",
srcs = ["saver_test.py"],
@@ -241,6 +254,7 @@ py_test(
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":parameter_server",
":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 529c99b37c..3acecd283c 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -1056,7 +1056,7 @@
"\n",
" attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
" result.append(index_word[predicted_id])\n",
"\n",
" if index_word[predicted_id] == '<end>':\n",
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
index 40bc098724..e0d5e494d4 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb
@@ -610,7 +610,7 @@
"\n",
" # using a multinomial distribution to predict the word returned by the model\n",
" predictions = predictions / temperature\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
" \n",
" # We pass the predicted word as the next input to the model\n",
" # along with the previous hidden state\n",
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
index 9557479885..c38a1597b8 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py
@@ -37,26 +37,32 @@ def get_default_hparams():
n_warmup_iters=3)
+def step(dynamics, optimizer, samples):
+ loss, grads, samples, _ = l2hmc.loss_and_grads(
+ dynamics, samples, loss_fn=l2hmc.compute_loss)
+ optimizer.apply_gradients(zip(grads, dynamics.variables))
+
+ return loss, samples
+
+
def warmup(dynamics,
optimizer,
n_iters=1,
n_samples=200,
- loss_fn=l2hmc.compute_loss):
+ step_fn=step):
"""Warmup optimization to reduce overhead."""
samples = tf.random_normal(
shape=[n_samples, dynamics.x_dim], dtype=tf.float32)
for _ in range(n_iters):
- _, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ _, samples = step_fn(dynamics, optimizer, samples)
def fit(dynamics,
samples,
optimizer,
- loss_fn=l2hmc.compute_loss,
+ step_fn=step,
n_iters=5000,
verbose=True,
logdir=None):
@@ -66,9 +72,7 @@ def fit(dynamics,
summary_writer = tf.contrib.summary.create_file_writer(logdir)
for i in range(n_iters):
- loss, grads, samples, _ = l2hmc.loss_and_grads(
- dynamics, samples, loss_fn=loss_fn)
- optimizer.apply_gradients(zip(grads, dynamics.variables))
+ loss, samples = step_fn(dynamics, optimizer, samples)
if verbose:
print("Iteration %d: loss %.4f" % (i, loss))
@@ -193,16 +197,16 @@ class L2hmcBenchmark(tf.test.Benchmark):
n_steps=hparams.n_steps,
eps=hparams.eps)
optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
- loss_fn = tfe.defun(l2hmc.compute_loss) if defun else l2hmc.compute_loss
+ step_fn = tfe.defun(step) if defun else step
# Warmup to reduce initialization effect when timing
- warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn)
+ warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, step_fn=step_fn)
# Training
samples = tf.random_normal(
shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32)
start_time = time.time()
- fit(dynamics, samples, optimizer, loss_fn=loss_fn, n_iters=hparams.n_iters)
+ fit(dynamics, samples, optimizer, step_fn=step_fn, n_iters=hparams.n_iters)
wall_time = time.time() - start_time
examples_per_sec = hparams.n_samples / wall_time
diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
index f1e1f99c57..560fc8c5a2 100644
--- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb
@@ -677,7 +677,7 @@
" attention_weights = tf.reshape(attention_weights, (-1, ))\n",
" attention_plot[t] = attention_weights.numpy()\n",
"\n",
- " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n",
+ " predicted_id = tf.argmax(predictions[0]).numpy()\n",
"\n",
" result += targ_lang.idx2word[predicted_id] + ' '\n",
"\n",
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
index fabd7b3e20..750bbc66f3 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md
@@ -23,4 +23,4 @@ Attribution-ShareAlike License and is available at
https://en.wikipedia.org/wiki/List_of_colors:_N-Z
This example was adapted from
- https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot
+ https://github.com/random-forests/tensorflow-workshop/tree/master/archive/extras/colorbot
diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py
new file mode 100644
index 0000000000..3a9e7b027e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/parameter_server.py
@@ -0,0 +1,289 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""EXPERIMENTAL utilities for parameter server training with eager execution.
+
+Note: this should eventually be merged with the distribution strategy for
+ParameterServer.
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import time
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+ """Creates a variable handle with information to do shape inference."""
+ container = ops.get_default_graph()._container # pylint: disable=protected-access
+ if container is None:
+ container = ""
+ handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+ if graph_mode:
+ return handle
+
+ with context.graph_mode(), ops.Graph().as_default() as graph:
+ h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
+ shared_name=shared_name,
+ name=name,
+ container=container)
+
+ # Tensor._handle_data contains information for the shape-inference code to
+ # know the shape and dtype of the variable pointed to by a handle. Since
+ # shape inference doesn't run in eager mode we copy this data here for when
+ # the handle is captured by an eager mode function.
+ # pylint: disable=protected-access
+ if ops._USE_C_SHAPES:
+ handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
+ else:
+ if h._handle_data is None:
+ ops.set_shape_and_handle_data_for_outputs(h.op)
+ handle._handle_data = h._handle_data
+ # pylint: enable=protected-access
+ # Clean up op->graph->op reference cycles.
+ ops.dismantle_graph(graph)
+ return handle
+
+
+class SharedVariable(resource_variable_ops.ResourceVariable):
+ """Experimental Variable designed for parameter server training.
+
+ A SharedVariable has a name and two instances of SharedVariable with the
+ same name will have the same value, even if they are in different Sessions,
+ as long as they are placed on the same device.
+
+ The storage associated with SharedVariables is also not deleted when they go
+ out of scope.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ name=None,
+ dtype=None,
+ constraint=None,
+ initialize=True,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, automatically watches this variable on GradientTape
+ whenever it's used.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ initialize: if True, runs initialization in eager execution; leaves the
+ variable uninitialized otherwise.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ """
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if isinstance(initial_value, ops.Tensor) and hasattr(
+ initial_value, "graph") and initial_value.graph.building_function:
+ raise ValueError("Tensor-typed variable initializers must either be "
+ "wrapped in an init_scope or callable "
+ "(e.g., `tf.Variable(lambda : "
+ "tf.truncated_normal([10, 40]))`) when building "
+ "functions. Please file a feature request if this "
+ "restriction inconveniences you.")
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.init_scope():
+ self._in_graph_mode = not context.executing_eagerly()
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ handle_name = ops._name_from_scope_name(name)
+ shared_name = handle_name
+ if init_from_fn:
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ if self._in_graph_mode:
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value(), name="initial_value", dtype=dtype)
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=self._in_graph_mode)
+ self._shape = initial_value.get_shape()
+ else:
+ initial_value = initial_value()
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.get_shape()
+ # pylint: enable=protected-access
+
+ # Or get the initial value from a Tensor or Python object.
+ else:
+ with ops.name_scope("Initializer"):
+ initial_value = ops.convert_to_tensor(
+ initial_value, name="initial_value", dtype=dtype)
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
+ self._handle = _eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=self._in_graph_mode)
+ self._shape = initial_value.get_shape()
+
+ self._unique_id = shared_name
+ self._initial_value = initial_value if self._in_graph_mode else None
+ self._handle_name = handle_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+
+ if self._in_graph_mode:
+ with ops.name_scope("IsInitialized"):
+ self._is_initialized_op = (
+ resource_variable_ops.var_is_initialized_op(self._handle))
+ if initial_value is not None:
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ self._initializer_op = (
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ self._try_guard_against_uninitialized_dependencies(
+ initial_value),
+ name=n))
+ with ops.name_scope("Read"), ops.colocate_with(self._handle):
+ # Manually assign reads to the handle's device to avoid log
+ # messages.
+ with ops.device(self._handle.device):
+ value = self._read_variable_op()
+ self._graph_element = value
+ self._cached_value = None
+ else:
+ if initialize:
+ resource_variable_ops.assign_variable_op(self._handle,
+ initial_value)
+ self._is_initialized_op = None
+ self._initializer_op = None
+ self._graph_element = None
+ self._cached_value = None
+
+ self._handle_deleter = None
+ self._cached_shape_as_list = None
+
+
+@contextlib.contextmanager
+def parameter_server_scope(is_chief, ps_job_name, num_ps_tasks):
+ """Strategy to use parameter servers in eager.
+
+ Creates SharedVariable objects for variables created in this scope. These
+ SharedVariable objects will be placed round-robin on the parameter servers
+ specified by the ps_job_name and num_ps_tasks arguments.
+
+ To use parameter servers you need only to wrap your model initialization in
+ this scope:
+
+ ```
+ with tf.contrib.eager.parameter_server_scope(
+ is_chief, ps_job_name, num_ps_tasks):
+ my_model = tf.keras.Sequential([...]) # Or
+ input = tf.keras.Input(...)
+ ....
+ my_model = tf.keras.Model(input, output)
+ my_model.compile(...)
+ # or other usages of the model.
+ ```
+
+ Args:
+ is_chief: Boolean. Whether this worker is responsible for initializing
+ variables.
+ ps_job_name: The name of the ps job in this cluster.
+ num_ps_tasks: The number of ps tasks to use.
+
+ Yields:
+ a context manager.
+ """
+ # Note: capturing in a list to allow assignment.
+ ps_index = [0]
+
+ def variable_creator_scope(unused_next_creator, **kwargs):
+ kwargs["initialize"] = is_chief
+ with ops.device(
+ "/job:%s/task:%s" % (ps_job_name, ps_index[0] % num_ps_tasks)):
+ ps_index[0] += 1
+ v = SharedVariable(**kwargs)
+ if not is_chief:
+ while not resource_variable_ops.var_is_initialized_op(v.handle):
+ time.sleep(10)
+ return v
+
+ with variable_scope.variable_creator_scope(variable_creator_scope):
+ yield
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 13029db975..ba6fe9701d 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import parameter_server
from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
@@ -33,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -120,6 +122,24 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x2)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ def testParameterServer(self):
+ with parameter_server.parameter_server_scope(
+ is_chief=True, ps_job_name=JOB_NAME, num_ps_tasks=3):
+ v0 = variables.Variable([1.0], name="v0")
+ v1 = variables.Variable([2.0], name="v1")
+ v0.assign(v0 * v1)
+ self.assertAllEqual(v0.read_value(), [2.0])
+ self.assertAllEqual(v0.device,
+ "/job:%s/replica:0/task:0/device:CPU:0" % JOB_NAME)
+ self.assertAllEqual(v1.device,
+ "/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME)
+ v1.assign_add(v1)
+ # Simulate aliasing another variable of the same name as v1
+ with ops.device("/job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ v1_replica = parameter_server.SharedVariable(
+ [1.0], name="v1", initialize=False)
+ self.assertAllEqual(v1_replica.read_value(), [4.0])
+
@run_sync_and_async
def testSimpleWeightRead(self):
"""Basic remote eager weight read."""
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 437b3d965d..1ea00fb7f3 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -18,6 +18,7 @@ py_library(
":boosted_trees",
":dnn",
":dnn_linear_combined",
+ ":dnn_with_layer_annotations",
":early_stopping",
":export",
":exporter",
@@ -127,6 +128,42 @@ py_test(
)
py_library(
+ name = "dnn_with_layer_annotations",
+ srcs = ["python/estimator/dnn_with_layer_annotations.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:head",
+ "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:optimizers",
+ ],
+)
+
+py_test(
+ name = "dnn_with_layer_annotations_test",
+ size = "medium",
+ srcs = ["python/estimator/dnn_with_layer_annotations_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "notsan", # b/67510291
+ ],
+ deps = [
+ ":dnn_with_layer_annotations",
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//tensorflow/python/estimator:dnn",
+ "//tensorflow/python/estimator:dnn_testing_utils",
+ "//tensorflow/python/estimator:export_export",
+ "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:pandas_io",
+ "//tensorflow/python/estimator:prediction_keys",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "dnn_linear_combined",
srcs = ["python/estimator/dnn_linear_combined.py"],
srcs_version = "PY2AND3",
@@ -227,9 +264,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:summary",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator:exporter",
],
)
@@ -241,7 +276,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":exporter",
- "//tensorflow/python:platform",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:exporter",
],
@@ -446,7 +481,6 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
- "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
@@ -501,13 +535,10 @@ py_library(
srcs = ["python/estimator/saved_model_estimator.py"],
deps = [
":export",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:training",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export",
"//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/saved_model",
],
)
@@ -522,16 +553,7 @@ py_test(
deps = [
":export",
":saved_model_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:metrics",
- "//tensorflow/python:platform",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow:tensorflow_py_no_contrib",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:export_export",
"//tensorflow/python/estimator:export_output",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 258860f263..78914ecaca 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib.estimator.python.estimator.baseline import *
from tensorflow.contrib.estimator.python.estimator.boosted_trees import *
from tensorflow.contrib.estimator.python.estimator.dnn import *
+from tensorflow.contrib.estimator.python.estimator.dnn_with_layer_annotations import *
from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import *
from tensorflow.contrib.estimator.python.estimator.early_stopping import *
from tensorflow.contrib.estimator.python.estimator.export import *
@@ -76,6 +77,8 @@ _allowed_symbols = [
'build_raw_supervised_input_receiver_fn',
'build_supervised_input_receiver_fn_from_input_fn',
'SavedModelEstimator'
+ 'DNNClassifierWithLayerAnnotations',
+ 'DNNRegressorWithLayerAnnotations',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bcce6..a1f1c5f3d7 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
return _input_fn
-class _BoostedTreesEstimator(estimator.Estimator):
+def _is_classification_head(head):
+ """Infers if the head is a classification head."""
+ # Check using all classification heads defined in canned/head.py. However, it
+ # is not a complete list - it does not check for other classification heads
+ # not defined in the head library.
+ # pylint: disable=protected-access
+ return isinstance(head,
+ (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+ head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+ # pylint: enable=protected-access
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -96,9 +109,12 @@ class _BoostedTreesEstimator(estimator.Estimator):
negative gain). For pre and post pruning, you MUST provide
tree_complexity >0.
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
"""
- # pylint:disable=protected-access
# HParams for the model.
+ # pylint: disable=protected-access
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
@@ -115,8 +131,14 @@ class _BoostedTreesEstimator(estimator.Estimator):
config=config)
super(_BoostedTreesEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
- # pylint:enable=protected-access
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=_is_classification_head(head))
+ # pylint: enable=protected-access
def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3750..e23d9c0fc4 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
[pred['predictions'] for pred in predictions])
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testContribEstimatorThatDFCIsInPredictions(self):
+ # pylint:disable=protected-access
+ head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ head=head,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+ # pylint:enable=protected-access
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
new file mode 100644
index 0000000000..3fd9f12c61
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -0,0 +1,429 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deep Neural Network estimators with layer annotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import pickle
+
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.saved_model import utils as saved_model_utils
+
+
+class LayerAnnotationsCollectionNames(object):
+ """Names for the collections containing the annotations."""
+
+ UNPROCESSED_FEATURES = 'layer_annotations/unprocessed_features'
+ PROCESSED_FEATURES = 'layer_annotatons/processed_features'
+ FEATURE_COLUMNS = 'layer_annotations/feature_columns'
+
+ @classmethod
+ def keys(cls, collection_name):
+ return '%s/keys' % collection_name
+
+ @classmethod
+ def values(cls, collection_name):
+ return '%s/values' % collection_name
+
+
+def serialize_feature_column(feature_column):
+ if isinstance(feature_column, feature_column_lib._EmbeddingColumn): # pylint: disable=protected-access
+ # We can't pickle nested functions, and we don't need the value of
+ # layer_creator in most cases anyway, so just discard its value.
+ args = feature_column._asdict()
+ args['layer_creator'] = None
+ temp = type(feature_column)(**args)
+ return pickle.dumps(temp)
+ return pickle.dumps(feature_column)
+
+
+def _to_any_wrapped_tensor_info(tensor):
+ """Converts a `Tensor` to a `TensorInfo` wrapped in a proto `Any`."""
+ any_buf = Any()
+ tensor_info = saved_model_utils.build_tensor_info(tensor)
+ any_buf.Pack(tensor_info)
+ return any_buf
+
+
+def make_input_layer_with_layer_annotations(original_input_layer):
+ """Make an input_layer replacement function that adds layer annotations."""
+
+ def input_layer_with_layer_annotations(features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
+ """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+
+ Generally a single example in training data is described with
+ FeatureColumns.
+ At the first layer of the model, this column oriented data should be
+ converted
+ to a single `Tensor`.
+
+ This is like tf.feature_column.input_layer, except with added
+ Integrated-Gradient annotations.
+
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values can be a `SparseTensor` or a `Tensor` depends
+ on corresponding `_FeatureColumn`.
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ weight_collections: A list of collection names to which the Variable will
+ be added. Note that variables will also be added to collections
+ `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ cols_to_vars: If not `None`, must be a dictionary that will be filled with
+ a mapping from `_FeatureColumn` to list of `Variable`s. For example,
+ after the call, we might have cols_to_vars = {_EmbeddingColumn(
+ categorical_column=_HashedCategoricalColumn( key='sparse_feature',
+ hash_bucket_size=5, dtype=tf.string), dimension=10): [<tf.Variable
+ 'some_variable:0' shape=(5, 10), <tf.Variable 'some_variable:1'
+ shape=(5, 10)]} If a column creates no variables, its value will be an
+ empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from '_FeatureColumn' to the associated output
+ `Tensor`s.
+
+ Returns:
+ A `Tensor` which represents input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on given `feature_columns`.
+
+ Raises:
+ ValueError: features and feature_columns have different lengths.
+ """
+
+ local_cols_to_output_tensors = {}
+ input_layer = original_input_layer(
+ features=features,
+ feature_columns=feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=local_cols_to_output_tensors)
+
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors = local_cols_to_output_tensors
+
+ # Annotate features.
+ # These are the parsed Tensors, before embedding.
+
+ # Only annotate features used by FeatureColumns.
+ # We figure which ones are used by FeatureColumns by creating a parsing
+ # spec and looking at the keys.
+ spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ for key in spec.keys():
+ tensor = ops.convert_to_tensor(features[key])
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ # Annotate feature columns.
+ for column in feature_columns:
+ # TODO(cyfoo): Find a better way to serialize and deserialize
+ # _FeatureColumn.
+ ops.add_to_collection(LayerAnnotationsCollectionNames.FEATURE_COLUMNS,
+ serialize_feature_column(column))
+
+ for column, tensor in local_cols_to_output_tensors.items():
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.keys(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES), column.name)
+ ops.add_to_collection(
+ LayerAnnotationsCollectionNames.values(
+ LayerAnnotationsCollectionNames.PROCESSED_FEATURES),
+ _to_any_wrapped_tensor_info(tensor))
+
+ return input_layer
+
+ return input_layer_with_layer_annotations
+
+
+@contextlib.contextmanager
+def _monkey_patch(module, function, replacement):
+ old_function = getattr(module, function)
+ setattr(module, function, replacement)
+ yield
+ setattr(module, function, old_function)
+
+
+def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ n_classes=2,
+ weight_column=None,
+ label_vocabulary=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM):
+ """A classifier for TensorFlow DNN models with layer annotations.
+
+ This classifier is fuctionally identical to estimator.DNNClassifier as far as
+ training and evaluating models is concerned. The key difference is that this
+ classifier adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+ Integrated Gradients is a method for attributing a classifier's predictions
+ to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+ instance, the method assigns attribution scores to individual features in
+ proportion to the feature's importance to the classifier's prediction.
+
+ See estimator.DNNClassifer for example code for training and evaluating models
+ using this classifier.
+
+ This classifier is checkpoint-compatible with estimator.DNNClassifier and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNClassifier(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNClassifierWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_saved_model(
+ export_dir_base, serving_input_receiver, ...
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+ hidden_units: Iterable of number hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+ model_dir: Directory to save model parameters, graph and etc. This can also
+ be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ n_classes: Number of label classes. Defaults to 2, namely binary
+ classification. Must be > 1.
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+ label_vocabulary: A list of strings represents possible label values. If
+ given, labels must be string type and have any value in
+ `label_vocabulary`. If it is not given, that means labels are already
+ encoded as integer or float within [0, 1] for `n_classes=2` and encoded as
+ integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there
+ will be errors if vocabulary is not provided and labels are string.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+ reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ DNNClassifier with layer annotations.
+ """
+
+ original = dnn.DNNClassifier(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ n_classes=n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction)
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
+
+
+def DNNRegressorWithLayerAnnotations( # pylint: disable=invalid-name
+ hidden_units,
+ feature_columns,
+ model_dir=None,
+ label_dimension=1,
+ weight_column=None,
+ optimizer='Adagrad',
+ activation_fn=nn.relu,
+ dropout=None,
+ input_layer_partitioner=None,
+ config=None,
+ warm_start_from=None,
+ loss_reduction=losses.Reduction.SUM,
+):
+ """A regressor for TensorFlow DNN models with layer annotations.
+
+ This regressor is fuctionally identical to estimator.DNNRegressor as far as
+ training and evaluating models is concerned. The key difference is that this
+ classifier adds additional layer annotations, which can be used for computing
+ Integrated Gradients.
+
+ Integrated Gradients is a method for attributing a classifier's predictions
+ to its input features (https://arxiv.org/pdf/1703.01365.pdf). Given an input
+ instance, the method assigns attribution scores to individual features in
+ proportion to the feature's importance to the classifier's prediction.
+
+ See estimator.DNNRegressor for example code for training and evaluating models
+ using this regressor.
+
+ This regressor is checkpoint-compatible with estimator.DNNRegressor and
+ therefore the following should work seamlessly:
+
+ # Instantiate ordinary estimator as usual.
+ estimator = tf.estimator.DNNRegressor(
+ config, feature_columns, hidden_units, ...)
+
+ # Train estimator, export checkpoint.
+ tf.estimator.train_and_evaluate(estimator, ...)
+
+ # Instantiate estimator with annotations with the same configuration as the
+ # ordinary estimator.
+ estimator_with_annotations = (
+ tf.contrib.estimator.DNNRegressorWithLayerAnnotations(
+ config, feature_columns, hidden_units, ...))
+
+ # Call export_savedmodel with the same arguments as the ordinary estimator,
+ # using the checkpoint produced for the ordinary estimator.
+ estimator_with_annotations.export_saved_model(
+ export_dir_base, serving_input_receiver, ...
+ checkpoint_path='/path/to/ordinary/estimator/checkpoint/model.ckpt-1234')
+
+ Args:
+ hidden_units: Iterable of number hidden units per layer. All layers are
+ fully connected. Ex. `[64, 32]` means first layer has 64 nodes and second
+ one has 32.
+ feature_columns: An iterable containing all the feature columns used by the
+ model. All items in the set should be instances of classes derived from
+ `_FeatureColumn`.
+ model_dir: Directory to save model parameters, graph and etc. This can also
+ be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ label_dimension: Number of regression targets per example. This is the size
+ of the last dimension of the labels and logits `Tensor` objects
+ (typically, these have shape `[batch_size, label_dimension]`).
+ weight_column: A string or a `_NumericColumn` created by
+ `tf.feature_column.numeric_column` defining feature column representing
+ weights. It is used to down weight or boost examples during training. It
+ will be multiplied by the loss of the example. If it is a string, it is
+ used as a key to fetch weight tensor from the `features`. If it is a
+ `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then
+ weight_column.normalizer_fn is applied on it to get weight tensor.
+ optimizer: An instance of `tf.Optimizer` used to train the model. Defaults
+ to Adagrad optimizer.
+ activation_fn: Activation function applied to each layer. If `None`, will
+ use `tf.nn.relu`.
+ dropout: When not `None`, the probability we will drop out a given
+ coordinate.
+ input_layer_partitioner: Optional. Partitioner for input layer. Defaults to
+ `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
+ config: `RunConfig` object to configure the runtime settings.
+ warm_start_from: A string filepath to a checkpoint to warm-start from, or a
+ `WarmStartSettings` object to fully configure warm-starting. If the
+ string filepath is provided instead of a `WarmStartSettings`, then all
+ weights are warm-started, and it is assumed that vocabularies and Tensor
+ names are unchanged.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
+ reduce training loss over batch. Defaults to `SUM`.
+
+ Returns:
+ DNNRegressor with layer annotations.
+ """
+
+ original = dnn.DNNRegressor(
+ hidden_units=hidden_units,
+ feature_columns=feature_columns,
+ model_dir=model_dir,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
+ optimizer=optimizer,
+ activation_fn=activation_fn,
+ dropout=dropout,
+ input_layer_partitioner=input_layer_partitioner,
+ config=config,
+ warm_start_from=warm_start_from,
+ loss_reduction=loss_reduction,
+ )
+
+ def _model_fn(features, labels, mode, config):
+ with _monkey_patch(
+ feature_column_lib, 'input_layer',
+ make_input_layer_with_layer_annotations(
+ feature_column_lib.input_layer)):
+ return original.model_fn(features, labels, mode, config)
+
+ return estimator.Estimator(
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ warm_start_from=warm_start_from)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
new file mode 100644
index 0000000000..2fe3d4c72e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations_test.py
@@ -0,0 +1,611 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dnn_with_layer_annotations.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.contrib.estimator.python.estimator import dnn_with_layer_annotations
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import dnn
+from tensorflow.python.estimator.canned import dnn_testing_utils
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import queue_runner
+
+try:
+ # pylint: disable=g-import-not-at-top
+ import pandas as pd
+ HAS_PANDAS = True
+except IOError:
+ # Pandas writes a temporary file during import. If it fails, don't use pandas.
+ HAS_PANDAS = False
+except ImportError:
+ HAS_PANDAS = False
+
+
+def _dnn_classifier_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
+ _dnn_regressor_fn)
+
+
+class DNNWithLayerAnnotationsClassifierEvaluateTest(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+class DNNClassifierWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNClassifierTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn)
+
+
+def _dnn_regressor_fn(*args, **kwargs):
+ return dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ *args, **kwargs)
+
+
+class DNNWithLayerAnnotationsTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def _getLayerAnnotationCollection(self, graph, collection_name):
+ keys = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.keys(
+ collection_name))
+ values = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames.values(
+ collection_name))
+ if len(keys) != len(values):
+ raise ValueError('keys and values should have same length. lengths were: '
+ '%d and %d, and elements were %s and %s' %
+ (len(keys), len(values), keys, values))
+ return dict(zip(keys, values))
+
+ def _testAnnotationsPresentForEstimator(self, estimator_class):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(1,)),
+ feature_column.embedding_column(
+ feature_column.categorical_column_with_vocabulary_list(
+ 'y', vocabulary_list=['a', 'b', 'c']),
+ dimension=3)
+ ]
+ estimator = estimator_class(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ model_dir=self._model_dir)
+ model_fn = estimator.model_fn
+
+ graph = ops.Graph()
+ with graph.as_default():
+ model_fn({
+ 'x': array_ops.constant([1.0]),
+ 'y': array_ops.constant(['a'])
+ }, {},
+ model_fn_lib.ModeKeys.PREDICT,
+ config=None)
+
+ unprocessed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .UNPROCESSED_FEATURES)
+ processed_features = self._getLayerAnnotationCollection(
+ graph, dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .PROCESSED_FEATURES)
+ feature_columns = graph.get_collection(
+ dnn_with_layer_annotations.LayerAnnotationsCollectionNames
+ .FEATURE_COLUMNS)
+
+ self.assertItemsEqual(unprocessed_features.keys(), ['x', 'y'])
+ self.assertEqual(2, len(processed_features.keys()))
+ self.assertEqual(2, len(feature_columns))
+
+ def testAnnotationsPresentForClassifier(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations)
+
+ def testAnnotationsPresentForRegressor(self):
+ self._testAnnotationsPresentForEstimator(
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations)
+
+ def _testCheckpointCompatibleWithNonAnnotatedEstimator(
+ self, train_input_fn, predict_input_fn, non_annotated_class,
+ annotated_class, prediction_key, estimator_args):
+ input_dimension = 2
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ estimator = non_annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ **estimator_args)
+
+ estimator.train(train_input_fn, steps=10)
+
+ predictions = np.array(
+ [x[prediction_key] for x in estimator.predict(predict_input_fn)])
+
+ annotated_estimator = annotated_class(
+ model_dir=self._model_dir,
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ warm_start_from=self._model_dir,
+ **estimator_args)
+
+ annotated_predictions = np.array([
+ x[prediction_key] for x in annotated_estimator.predict(predict_input_fn)
+ ])
+
+ self.assertAllEqual(predictions.shape, annotated_predictions.shape)
+ for i, (a, b) in enumerate(
+ zip(predictions.flatten(), annotated_predictions.flatten())):
+ self.assertAlmostEqual(a, b, msg='index=%d' % i)
+
+ def testCheckpointCompatibleForClassifier(self):
+ n_classes = 2
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(
+ np.rint(data[:batch_size]).astype(np.int64), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNClassifier,
+ dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PROBABILITIES,
+ estimator_args={'n_classes': n_classes})
+
+ def testCheckpointCompatibleForRegressor(self):
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._testCheckpointCompatibleWithNonAnnotatedEstimator(
+ train_input_fn,
+ predict_input_fn,
+ dnn.DNNRegressor,
+ dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations,
+ prediction_key=prediction_keys.PredictionKeys.PREDICTIONS,
+ estimator_args={'label_dimension': label_dimension})
+
+
+class DNNRegressorWithLayerAnnotationsEvaluateTest(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsPredictTest(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+class DNNRegressorWithLayerAnnotationsTrainTest(
+ dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn)
+
+
+def _queue_parsed_features(feature_map):
+ tensors_to_enqueue = []
+ keys = []
+ for key, tensor in six.iteritems(feature_map):
+ keys.append(key)
+ tensors_to_enqueue.append(tensor)
+ queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+ input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+ queue_runner.add_queue_runner(
+ queue_runner.QueueRunner(input_queue,
+ [input_queue.enqueue(tensors_to_enqueue)]))
+ dequeued_tensors = input_queue.dequeue()
+ return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+class DNNRegressorWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNRegressorWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ label_dimension=label_dimension,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, label_dimension), predictions.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data},
+ y=data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, y=data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ label_dimension = 1
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(data)
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ label_dimension = 2
+ batch_size = 10
+ data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, label_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=label_dimension,
+ label_dimension=label_dimension,
+ batch_size=batch_size)
+
+
+class DNNClassifierWithLayerAnnotationsIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _as_label(self, data_in_float):
+ return np.rint(data_in_float).astype(np.int64)
+
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size):
+ feature_columns = [
+ feature_column.numeric_column('x', shape=(input_dimension,))
+ ]
+ est = dnn_with_layer_annotations.DNNClassifierWithLayerAnnotations(
+ hidden_units=(2, 2),
+ feature_columns=feature_columns,
+ n_classes=n_classes,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+
+ # PREDICT
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PROBABILITIES]
+ for x in est.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
+
+ # EXPORT
+ feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+ export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+ serving_input_receiver_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ n_classes = 3
+ input_dimension = 2
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ x_data = data.reshape(batch_size, input_dimension)
+ y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
+ # learn y = x
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data},
+ y=y_data,
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': x_data}, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_pandas_input_fn(self):
+ """Tests complete flow with pandas_input_fn."""
+ if not HAS_PANDAS:
+ return
+ input_dimension = 1
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)
+ x = pd.DataFrame({'x': data})
+ y = pd.Series(self._as_label(data))
+ train_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+ eval_input_fn = pandas_io.pandas_input_fn(
+ x=x, y=y, batch_size=batch_size, shuffle=False)
+ predict_input_fn = pandas_io.pandas_input_fn(
+ x=x, batch_size=batch_size, shuffle=False)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+ def test_input_fn_from_parse_example(self):
+ """Tests complete flow with input_fn constructed from parse_example."""
+ input_dimension = 2
+ n_classes = 3
+ batch_size = 10
+ data = np.linspace(
+ 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
+ data = data.reshape(batch_size, input_dimension)
+
+ serialized_examples = []
+ for datum in data:
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'x':
+ feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=datum)),
+ 'y':
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(
+ value=self._as_label(datum[:1]))),
+ }))
+ serialized_examples.append(example.SerializeToString())
+
+ feature_spec = {
+ 'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+ 'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+ }
+
+ def _train_input_fn():
+ feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _eval_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ labels = features.pop('y')
+ return features, labels
+
+ def _predict_input_fn():
+ feature_map = parsing_ops.parse_example(
+ input_lib.limit_epochs(serialized_examples, num_epochs=1),
+ feature_spec)
+ features = _queue_parsed_features(feature_map)
+ features.pop('y')
+ return features, None
+
+ self._test_complete_flow(
+ train_input_fn=_train_input_fn,
+ eval_input_fn=_eval_input_fn,
+ predict_input_fn=_predict_input_fn,
+ input_dimension=input_dimension,
+ n_classes=n_classes,
+ batch_size=batch_size)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/early_stopping.py b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
index 3eab21d5ac..cafe8279c7 100644
--- a/tensorflow/contrib/estimator/python/estimator/early_stopping.py
+++ b/tensorflow/contrib/estimator/python/estimator/early_stopping.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import operator
import os
@@ -56,6 +57,13 @@ def make_early_stopping_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
should_stop_fn: `callable`, function that takes no arguments and returns a
@@ -108,6 +116,13 @@ def stop_if_higher_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -157,6 +172,13 @@ def stop_if_lower_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -206,6 +228,13 @@ def stop_if_no_increase_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -256,6 +285,13 @@ def stop_if_no_decrease_hook(estimator,
tf.estimator.train_and_evaluate(estimator, train_spec, ...)
```
+ Caveat: Current implementation supports early-stopping both training and
+ evaluation in local mode. In distributed mode, training can be stopped but
+ evaluation (where it's a separate job) will indefinitely wait for new model
+ checkpoints to evaluate, so you will need other means to detect and stop it.
+ Early-stopping evaluation in distributed mode requires changes in
+ `train_and_evaluate` API and will be addressed in a future revision.
+
Args:
estimator: A `tf.estimator.Estimator` instance.
metric_name: `str`, metric to track. "loss", "accuracy", etc.
@@ -306,7 +342,8 @@ def read_eval_metrics(eval_dir):
metrics[value.tag] = value.simple_value
if metrics:
eval_metrics_dict[event.step] = metrics
- return eval_metrics_dict
+ return collections.OrderedDict(
+ sorted(eval_metrics_dict.items(), key=lambda t: t[0]))
def _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index 66c46e66b7..49f7bbd320 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -53,6 +53,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
```
Current limitations of this approach are:
+
* It doesn't support multi-node distributed mode.
* It doesn't support saveable objects other than variables (such as boosted
tree support)
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 9e1f14f990..e344d7a23b 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -64,7 +64,6 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/feature_column:feature_column_py",
"//third_party/py/numpy",
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index b1820c10c8..9b0b9b1e1b 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -186,7 +186,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2x2_with_partial_expected_shape(self):
- with self.test_session():
+ with self.cached_session():
value = [[42, 43], [44, 45]]
actual_shape = [2, 2]
tensor = constant_op.constant(value, shape=actual_shape)
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 0f0813c07f..9725233e7f 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -111,7 +111,6 @@ tf_gen_op_wrapper_py(
cuda_py_test(
name = "fused_conv2d_bias_activation_op_test",
- size = "large",
srcs = ["python/ops/fused_conv2d_bias_activation_op_test.py"],
additional_deps = [
":fused_conv_py",
@@ -130,14 +129,12 @@ cuda_py_test(
"//tensorflow/python:variables",
],
tags = [
- "manual",
- "requires_cudnn6",
+ "requires-gpu-sm70",
],
)
cuda_py_test(
name = "fused_conv2d_bias_activation_benchmark",
- size = "large",
srcs = ["python/ops/fused_conv2d_bias_activation_benchmark.py"],
additional_deps = [
":fused_conv_py",
@@ -155,7 +152,6 @@ cuda_py_test(
],
main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
tags = [
- "manual",
- "requires_cudnn6",
+ "requires-gpu-sm70",
],
)
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 716bb87e38..e9e6464d06 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -497,7 +497,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
FORMAT_OIHW, filter_param.shape(), FORMAT_HWIO),
&maybe_transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter_param.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter_param.tensor<T, 4>()),
To32Bit(maybe_transformed_filter.tensor<T, 4>()));
filter = &maybe_transformed_filter;
}
diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
index 0185ef662c..4894298694 100644
--- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
+++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py
@@ -265,10 +265,10 @@ class FusedConv2DBiasActivationTest(test.TestCase):
tensors = []
for (data_format, use_gpu) in GetTestConfigs():
tensors.append(_SetupVal(data_format, use_gpu))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(tensors)
for i in range(1, len(values)):
- self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
+ self.assertAllClose(values[0], values[i], rtol=1e-3, atol=1e-3)
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, bias, strides,
padding):
@@ -282,7 +282,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
data_format, filter_format, dtype)
tensors.append(result)
ref_tensors.append(expected)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(tensors)
ref_values = sess.run(ref_tensors)
for i in range(len(tensors)):
@@ -493,7 +493,7 @@ class FusedConv2DBiasActivationTest(test.TestCase):
if gpu_only and not test.is_gpu_available():
tf_logging.info("Skipping OpEdgeCases tests.")
return
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Convolutional strides are not supported in "
@@ -873,9 +873,7 @@ class FusedConvInt8Tests(test.TestCase):
with self.test_session(use_gpu=True) as sess:
actual_y, expected_y = sess.run([actual, expected])
- tf_logging.info("actual_y = ", actual_y)
- tf_logging.info("expected_y = ", expected_y)
- self.assertTrue(np.array_equal(actual_y, expected_y))
+ self.assertAllClose(actual_y, expected_y, rtol=0, atol=1)
def testFusedConvInt8(self):
if not test.is_gpu_available(
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index d389748374..8bc4db8424 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -773,9 +773,9 @@ def mutual_information_penalty(
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
weights: Optional `Tensor` whose rank is either 0, or the same dimensions as
`structured_generator_inputs`.
scope: The scope for the operations performed in computing the loss.
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index a462b68e28..b9ac1bf151 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -91,9 +91,9 @@ class InfoGANModel(
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
- predicted_distributions: A list of tf.Distributions. Predicted by the
- recognizer, and used to evaluate the likelihood of the structured noise.
- List length should match `structured_generator_inputs`.
+ predicted_distributions: A list of `tfp.distributions.Distribution`s.
+ Predicted by the recognizer, and used to evaluate the likelihood of the
+ structured noise. List length should match `structured_generator_inputs`.
discriminator_and_aux_fn: The original discriminator function that returns
a tuple of (logits, `predicted_distributions`).
"""
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 58f348034f..64d6706199 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -399,7 +399,7 @@ class StarGANModelTest(test.TestCase):
target_tensor = train._generate_stargan_random_domain_target(
batch_size, domain_numbers)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
targets = sess.run(target_tensor)
self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
for target in targets:
@@ -676,7 +676,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
self.assertIsInstance(model_loss, namedtuples.GANLoss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 726f74c7b7..bb06f1c41c 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -138,6 +138,8 @@ class GdrMemoryManager : public RemoteMemoryManager {
Device* device, DeviceContext* device_context, bool on_host,
StatusCallback done) override;
+ static void RegMemVisitors();
+
protected:
Status CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint);
@@ -183,35 +185,51 @@ class GdrMemoryManager : public RemoteMemoryManager {
TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCGdrAllocator : public BFCAllocator {
- public:
- BFCGdrAllocator()
- : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_gdr_bfc") {}
-};
-class BFCGdrAllocatorFactory : public AllocatorFactory {
- public:
- Allocator* CreateAllocator() override { return new BFCGdrAllocator; }
-
- virtual SubAllocator* CreateSubAllocator(int numa_node) {
- return new BasicCPUAllocator(numa_node);
- }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory);
-
GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
: host_(host),
port_(port),
listening_(nullptr, EndpointDeleter),
stopped_(true),
- next_key_(0) {}
+ next_key_(0) {
+ static std::once_flag flag;
+ std::call_once(flag, []() { RegMemVisitors(); });
+}
GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
+/*static*/ void GdrMemoryManager::RegMemVisitors() {
+ SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ GdrMemoryManager::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes);
+ };
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
+
+ // Note we don't free allocated GPU memory so there is no free visitor
+ SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+}
+
Status GdrMemoryManager::Init() {
epfd_ = epoll_create1(0);
if (epfd_ == -1) {
@@ -271,48 +289,6 @@ Status GdrMemoryManager::Init() {
"cannot add server to epoll");
}
- Allocator* allocators[] = {
-#if GOOGLE_CUDA
- GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif // GOOGLE_CUDA
- ProcessState::singleton()->GetCPUAllocator(0),
- cpu_allocator(),
- };
-
- using namespace std::placeholders;
- VisitableAllocator::Visitor alloc_visitor =
- std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
- VisitableAllocator::Visitor free_visitor =
- std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);
-
- std::set<Allocator*> instrumented_;
-
- // Host memory allocators
- for (Allocator* allocator : allocators) {
- auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
- CHECK(visitable_allocator)
- << "is not visitable for instrumentation" << allocator->Name();
- // Make sure we don't instrument the same allocator twice
- if (instrumented_.find(allocator) == std::end(instrumented_)) {
- visitable_allocator->AddAllocVisitor(alloc_visitor);
- visitable_allocator->AddFreeVisitor(free_visitor);
- instrumented_.insert(allocator);
- LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
- }
- }
-
-#if GOOGLE_CUDA
- VisitableAllocator::Visitor cuda_alloc_visitor =
- std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
- if (IsGDRAvailable()) {
- // Note we don't free allocated GPU memory so there is no free visitor
- int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
- GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
- cuda_alloc_visitor);
- LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
- }
-#endif // GOOGLE_CUDA
-
return Status::OK();
}
diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
index fed8a771cc..27aed091c2 100644
--- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
+++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py
@@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase):
([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
def testGrid2LSTMCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid2BasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellTied(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([2, 2])
@@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase):
[[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
def testGrid2BasicRNNCellWithRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid1LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
x = array_ops.zeros([1, 3])
@@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGrid3LSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase):
"""
def testGridRNNEdgeCasesLikeRelu(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([3, 2])
@@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase):
self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
def testGridRNNEdgeCasesNoOutput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape(), (3, num_units))
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((batch_size, input_size))
@@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase):
self.assertEqual(out[0].get_shape()[1], num_units)
self.assertEqual(out[0].dtype, inp.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
input_value = np.ones((3, input_size))
@@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase):
def testGrid2LSTMCellLegacy(self):
"""Test for legacy case (when state_is_tuple=False)."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
'root', initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
diff --git a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
index d796e43d87..f7f1189bb9 100644
--- a/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
+++ b/tensorflow/contrib/hadoop/python/kernel_tests/hadoop_test.py
@@ -51,7 +51,7 @@ class SequenceFileDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for _ in range(num_repeats): # Dataset is repeated.
for i in range(25): # 25 records.
diff --git a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
index 6e0e628655..bf398b838d 100644
--- a/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
+++ b/tensorflow/contrib/hadoop/python/ops/hadoop_dataset_ops.py
@@ -19,14 +19,14 @@ from __future__ import print_function
from tensorflow.contrib.hadoop.python.ops import gen_dataset_ops
from tensorflow.contrib.hadoop.python.ops import hadoop_op_loader # pylint: disable=unused-import
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class SequenceFileDataset(Dataset):
+class SequenceFileDataset(dataset_ops.DatasetSource):
"""A Sequence File Dataset that reads the sequence file."""
def __init__(self, filenames):
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 9ed017592a..f44edaa14c 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class InputPipelineOpsTest(test.TestCase):
def testObtainNext(self):
- with self.test_session():
+ with self.cached_session():
var = state_ops.variable_op([], dtypes.int64)
state_ops.assign(var, -1).op.run()
c = constant_op.constant(["a", "b"])
@@ -45,7 +45,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNext(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
session.run([variables.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
@@ -65,7 +65,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochs(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
session.run([
variables.local_variables_initializer(),
@@ -75,7 +75,7 @@ class InputPipelineOpsTest(test.TestCase):
def testSeekNextLimitEpochsThree(self):
string_list = ["a", "b", "c"]
- with self.test_session() as session:
+ with self.cached_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
session.run([
variables.local_variables_initializer(),
diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
index 621911876f..08ebcdb544 100644
--- a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
+++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py
@@ -54,7 +54,7 @@ class KafkaDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from topic 0.
sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1})
for i in range(5):
diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
index a1624614d1..7129f09e8b 100644
--- a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
+++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kafka.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kafka.python.ops import kafka_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KafkaDataset(Dataset):
+class KafkaDataset(dataset_ops.DatasetSource):
"""A Kafka Dataset that consumes the message.
"""
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
index 72507539f8..4d5cc24ce0 100644
--- a/tensorflow/contrib/kernel_methods/python/losses_test.py
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLogitsShape(self):
"""An error is raised when logits have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2,))
labels = constant_op.constant([0, 1])
with self.assertRaises(ValueError):
@@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsShape(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(1, 1, 2))
with self.assertRaises(ValueError):
@@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidWeightsShape(self):
"""An error is raised when weights have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2,))
weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
@@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInvalidLabelsDtype(self):
"""An error is raised when labels have invalid shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], dtype=dtypes.float32)
with self.assertRaises(ValueError):
@@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNoneWeightRaisesValueError(self):
"""An error is raised when weights are None."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0])
with self.assertRaises(ValueError):
@@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesSameRank(self):
"""Error raised when weights and labels have same ranks, different sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
@@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
"""Error raised when weights and labels have different ranks and sizes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
labels = constant_op.constant([1, 0], shape=(2, 1))
weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
@@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testOutOfRangeLabels(self):
"""An error is raised when labels are not in [0, num_classes)."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([1, 0, 4])
@@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt32Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
[0.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
@@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testZeroLossInt64Labels(self):
"""Loss is 0 if true class logits sufficiently higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
[-0.5, 0.8, -1.0]])
labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
@@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
]
for batch_size, num_classes in logits_shapes:
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(
dtypes.float32, shape=(batch_size, num_classes))
labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
@@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testCorrectPredictionsSomeClassesInsideMargin(self):
"""Loss is > 0 even if true class logits are higher than other classes."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
[1.5, 1.8, -1.0]])
labels = constant_op.constant([0, 2, 1])
@@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictions(self):
"""Loss is >0 when an incorrect class has higher logits than true class."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
[0.5, -1.8, 2.0]])
labels = constant_op.constant([1, 0, 2])
@@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsColumnLabels(self):
"""Same as above but labels is a rank-2 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testIncorrectPredictionsZeroWeights(self):
"""Loss is 0 when all weights are missing even if predictions are wrong."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithPythonScalarWeights(self):
"""Weighted loss is correctly computed when weights is a python scalar."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWithScalarTensorWeights(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0]])
labels = constant_op.constant([1, 0, 2], shape=(3, 1))
@@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase):
def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
"""Weighted loss is correctly computed when weights is a rank-0 tensor."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
[0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
labels = constant_op.constant([1, 0, 2, 1])
diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
index 2ff4d41d75..bad0a596a7 100644
--- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
+++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py
@@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testInvalidInputShape(self):
x = constant_op.constant([[2.0, 1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10)
with self.assertRaisesWithPredicateMatch(
dense_kernel_mapper.InvalidShapeError,
@@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0],
[4.0, -2.0, -1.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 10, 1.0)
mapped_x1 = rffm.map(x1)
mapped_x2 = rffm.map(x2)
@@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
def testSameOmegaReused(self):
x = constant_op.constant([[2.0, 1.0, 0.0]])
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(3, 100)
mapped_x = rffm.map(x)
mapped_x_copy = rffm.map(x)
@@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm1 = RandomFourierFeatureMapper(3, 100, stddev)
@@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
y = constant_op.constant([[1.0, -1.0, 2.0]])
stddev = 3.0
- with self.test_session():
+ with self.cached_session():
# The mapped dimension is fairly small, so the kernel approximation is
# very rough.
rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0)
@@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase):
normalized_points = [nn.l2_normalize(point, dim=1) for point in points]
total_absolute_error = 0.0
- with self.test_session():
+ with self.cached_session():
rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0)
# Cache mappings so that they are not computed multiple times.
cached_mappings = dict((point, rffm.map(point))
diff --git a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
index 7289b45c50..bf89922318 100644
--- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
+++ b/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
@@ -64,7 +64,7 @@ class KinesisDatasetTest(test.TestCase):
init_batch_op = iterator.make_initializer(batch_dataset)
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 1.
sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
for i in range(10):
@@ -108,7 +108,7 @@ class KinesisDatasetTest(test.TestCase):
get_next = iterator.get_next()
data = list()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Basic test: read from shard 0 of stream 2.
sess.run(
init_op, feed_dict={
diff --git a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
index ca2df95ba4..75806dbbeb 100644
--- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
+++ b/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
@@ -17,15 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
+from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-class KinesisDataset(Dataset):
+class KinesisDataset(dataset_ops.DatasetSource):
"""A Kinesis Dataset that consumes the message.
Kinesis is a managed service provided by AWS for data streaming.
diff --git a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
index 28ddaa69a1..155d06a08e 100644
--- a/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
+++ b/tensorflow/contrib/layers/python/kernel_tests/sparse_feature_cross_op_test.py
@@ -45,7 +45,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_dense(self):
@@ -66,7 +66,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_sparse(self):
@@ -80,7 +80,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
'55555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_dense(self):
@@ -99,7 +99,7 @@ class SparseCrossOpTest(test.TestCase):
'55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
'999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_sparse_cross_dense(self):
@@ -117,7 +117,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_sparse_input(self):
@@ -133,7 +133,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
'5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x3x3(self):
@@ -176,7 +176,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x1x2(self):
@@ -196,7 +196,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_large_batch(self):
@@ -229,7 +229,7 @@ class SparseCrossOpTest(test.TestCase):
])
expected_out = self._sparse_tensor(col_out)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_one_column_empty(self):
@@ -242,7 +242,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([], 1),
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_some_columns_empty(self):
@@ -261,7 +261,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
]], 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_all_columns_empty(self):
@@ -273,7 +273,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([]), self._sparse_tensor([]),
self._sparse_tensor([])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_hashed_output_zero_bucket(self):
@@ -288,7 +288,7 @@ class SparseCrossOpTest(test.TestCase):
hashed_output=True)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[3735511728867393167]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_zero_bucket_v2(self):
@@ -304,7 +304,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[1971693436396284976]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
# TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -321,7 +321,7 @@ class SparseCrossOpTest(test.TestCase):
num_buckets=100)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[74]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_v2(self):
@@ -338,7 +338,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[83]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output_v1_has_collision(self):
@@ -384,7 +384,7 @@ class SparseCrossOpTest(test.TestCase):
],
hashed_output=True,
num_buckets=1000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
out = sess.run(op)
self.assertEqual(6, len(out.values))
self.assertAllEqual([[0, i] for i in range(6)], out.indices)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 60e1d85ea9..17ee8c0733 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -112,9 +112,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
dtype = sparse_weights.dtype if sparse_weights is not None else None
if isinstance(embedding_weights, variables.PartitionedVariable):
embedding_weights = list(embedding_weights)
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
contrib_tensor_util.assert_same_float_dtype(embedding_weights +
[sparse_weights])
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 69d927e1b3..2fdcd849b0 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -21,8 +21,6 @@ from __future__ import print_function
import six
from tensorflow.contrib import framework as contrib_framework
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -433,12 +431,11 @@ def _multiply_gradients(grads_and_vars, gradient_multipliers):
if (grad is not None and
(var in gradient_multipliers or var.name in gradient_multipliers)):
key = var if var in gradient_multipliers else var.name
- multiplier = constant_op.constant(
- gradient_multipliers[key], dtype=dtypes.float32)
+ multiplier = gradient_multipliers[key]
if isinstance(grad, ops.IndexedSlices):
grad_values = grad.values * multiplier
grad = ops.IndexedSlices(grad_values, grad.indices, grad.dense_shape)
else:
- grad *= multiplier
+ grad *= math_ops.cast(multiplier, grad.dtype)
multiplied_grads_and_vars.append((grad, var))
return multiplied_grads_and_vars
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 29dede2a49..b4d1239e76 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -250,6 +250,42 @@ class OptimizersTest(test.TestCase):
self.assertAlmostEqual(var_value, 6.5, 4)
self.assertEqual(global_step_value, 1)
+ def testGradientMultiplyInt32Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.float32, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
+ def testGradientMultiplyInt64Tensor(self):
+ with self.cached_session() as session:
+ x, var, loss, global_step = _setup_model()
+ v = array_ops.placeholder(dtypes.float64, [])
+ train = optimizers_lib.optimize_loss(
+ loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer="SGD",
+ gradient_multipliers={var: v})
+ variables.global_variables_initializer().run()
+ session.run(train, feed_dict={x: 5, v: 7.})
+ var_value, global_step_value = session.run([var, global_step])
+ # var(0) = 10, x = 5, var(0)/dx = 5,
+ # var(1) = var(0) - learning_rate * gradient_multiplier * var(0)/dx
+ self.assertAlmostEqual(var_value, 6.5, 4)
+ self.assertEqual(global_step_value, 1)
+
def testIgnoreVariablesWithNoGradients(self):
_, _, loss, global_step = _setup_model()
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 69bb6be814..8a6b4f68a8 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -396,7 +396,7 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
def _mean_squared_loss(logits, target):
# To prevent broadcasting inside "-".
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
logits.get_shape().assert_is_compatible_with(target.get_shape())
return math_ops.square(logits - math_ops.to_float(target))
@@ -405,7 +405,7 @@ def _mean_squared_loss(logits, target):
def _log_loss_with_two_classes(logits, target):
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
if len(target.get_shape()) == 1:
- target = array_ops.expand_dims(target, dim=[1])
+ target = array_ops.expand_dims(target, axis=1)
loss_vec = nn.sigmoid_cross_entropy_with_logits(
labels=math_ops.to_float(target), logits=logits)
return loss_vec
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index ded93d4a7f..c6f79e00d5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -563,10 +563,10 @@ def _mean_squared_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, axis=(1,))
+ logits = array_ops.expand_dims(logits, axis=1)
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = math_ops.square(logits - math_ops.to_float(labels), name=name)
return _compute_weighted_loss(loss, weights)
@@ -579,10 +579,10 @@ def _poisson_loss(labels, logits, weights=None):
labels = ops.convert_to_tensor(labels)
# To prevent broadcasting inside "-".
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
# TODO(zakaria): make sure it does not recreate the broadcast bug.
if len(logits.get_shape()) == 1:
- logits = array_ops.expand_dims(logits, axis=(1,))
+ logits = array_ops.expand_dims(logits, axis=1)
logits.get_shape().assert_is_compatible_with(labels.get_shape())
loss = nn.log_poisson_loss(labels, logits, compute_full_loss=True,
name=name)
@@ -797,7 +797,7 @@ def _log_loss_with_two_classes(labels, logits, weights=None):
# TODO(ptucker): This will break for dynamic shapes.
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] labels.
if len(labels.get_shape()) == 1:
- labels = array_ops.expand_dims(labels, axis=(1,))
+ labels = array_ops.expand_dims(labels, axis=1)
loss = nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits,
name=name)
return _compute_weighted_loss(loss, weights)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index d5c02124ac..33180b778a 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -234,7 +234,7 @@ class GraphActionsTest(test.TestCase):
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
def test_infer_different_default_graph(self):
- with self.test_session():
+ with self.cached_session():
self._assert_ckpt(self._output_dir, False)
with ops.Graph().as_default():
in0, in1, out = self._build_inference_graph()
diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
index 2f33a2b74d..0e5ea6b9f7 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.training import adam
class Seq2SeqTest(test.TestCase):
def testRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -65,7 +65,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testBasicRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -81,7 +81,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -98,7 +98,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingRNNDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -124,7 +124,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].h.shape)
def testEmbeddingRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -228,7 +228,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testEmbeddingTiedRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -316,7 +316,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -341,7 +341,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -367,7 +367,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -391,7 +391,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testDynamicAttentionDecoder2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.GRUCell(2)
@@ -416,7 +416,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda
@@ -448,7 +448,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0][1].h.shape)
def testDynamicAttentionDecoderStateIsTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda
@@ -479,7 +479,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0][1].h.shape)
def testEmbeddingAttentionDecoder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
@@ -513,7 +513,7 @@ class Seq2SeqTest(test.TestCase):
self.assertEqual((2, 2), res[0].shape)
def testEmbeddingAttentionSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -622,7 +622,7 @@ class Seq2SeqTest(test.TestCase):
# self.assertAllClose(res1, res3)
def testOne2ManyRNNSeq2Seq(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
enc_inp = [
@@ -712,7 +712,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(res1, res3)
def testSequenceLoss(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
targets = [
constant_op.constant(
@@ -748,7 +748,7 @@ class Seq2SeqTest(test.TestCase):
self.assertAllClose(9.656628, res)
def testSequenceLossByExample(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_classes = 5
logits = [
constant_op.constant(
@@ -778,7 +778,7 @@ class Seq2SeqTest(test.TestCase):
# classes = 10
# buckets = [(4, 4), (8, 8)]
- # with self.test_session():
+ # with self.cached_session():
# # Here comes a sample Seq2Seq model using GRU cells.
# def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
# """Example sequence-to-sequence model that uses GRU cells."""
@@ -839,7 +839,7 @@ class Seq2SeqTest(test.TestCase):
random.seed(111)
np.random.seed(111)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# We use sampled softmax so we keep output projection separate.
w = variable_scope.get_variable("proj_w", [24, classes])
w_t = array_ops.transpose(w)
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
deleted file mode 100644
index 78b7970069..0000000000
--- a/tensorflow/contrib/linalg/BUILD
+++ /dev/null
@@ -1,44 +0,0 @@
-# Description:
-# Contains classes that provide access to common method of a [batch] matrix,
-# without the need to instantiate the matrix.
-# This allows for exploitation of structure, as well as a generic interface
-# suitable for iterative solvers.
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-py_library(
- name = "linalg_py",
- srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:util",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-cuda_py_test(
- name = "linear_operator_addition_test",
- size = "small",
- srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
- additional_deps = [
- ":linalg_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
-)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
deleted file mode 100644
index cbe4c03e4d..0000000000
--- a/tensorflow/contrib/linalg/__init__.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Linear algebra libraries.
-
-See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
-guide.
-
-@@LinearOperator
-@@LinearOperatorBlockDiag
-@@LinearOperatorCirculant
-@@LinearOperatorCirculant2D
-@@LinearOperatorCirculant3D
-@@LinearOperatorDiag
-@@LinearOperatorIdentity
-@@LinearOperatorScaledIdentity
-@@LinearOperatorFullMatrix
-@@LinearOperatorKronecker
-@@LinearOperatorLowerTriangular
-@@LinearOperatorLowRankUpdate
-@@LinearOperatorComposition
-@@add_operators
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
-from tensorflow.python.ops.linalg.linear_operator import *
-from tensorflow.python.ops.linalg.linear_operator_block_diag import *
-from tensorflow.python.ops.linalg.linear_operator_circulant import *
-from tensorflow.python.ops.linalg.linear_operator_composition import *
-from tensorflow.python.ops.linalg.linear_operator_diag import *
-from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
-from tensorflow.python.ops.linalg.linear_operator_identity import *
-from tensorflow.python.ops.linalg.linear_operator_kronecker import *
-from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
-from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
-
-# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-remove_undocumented(__name__)
diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/contrib/linalg/python/__init__.py
deleted file mode 100644
index c5ca3a623f..0000000000
--- a/tensorflow/contrib/linalg/python/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""ops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 1d2db1cec8..9ecf023e03 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -134,7 +134,7 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):
return examples_dict, variables_dict
-def make_variable_dict(max_age, max_gender, partitioned=False):
+def make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):
# TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from
# examples_dict.
partitioner = None
@@ -142,14 +142,15 @@ def make_variable_dict(max_age, max_gender, partitioned=False):
partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2,
axis=0)
with variable_scope.variable_scope(
- name_or_scope='variables',
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
partitioner=partitioner):
- age_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_age + 1], dtype=dtypes.float32))
- gender_weights = variables_lib.Variable(
- array_ops.zeros(
- [max_gender + 1], dtype=dtypes.float32))
+ age_weights = variable_scope.get_variable(
+ name='age',
+ initializer=array_ops.zeros([max_age + 1], dtype=dtypes.float32))
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([max_gender + 1], dtype=dtypes.float32))
return dict(
sparse_features_weights=[age_weights, gender_weights],
dense_features_weights=[])
@@ -242,7 +243,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -290,7 +291,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1, partitioned=True)
+ variables = make_variable_dict(1, 1, num_shards, partitioned=True)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -322,6 +323,68 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
self.assertAllClose(
0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+ def testSomePartitionedPrimals(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [0],
+ 'gender': [1]
+ }, 1),
+ ]
+ example_weights = [1.0, 1.0]
+ for num_shards in _SHARD_NUMBERS:
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ # Explicitly make age a [1]-shaped Variable (which cannot be
+ # partitioned), while making gender a PartitionedVariable.
+ age_weights = variables_lib.Variable(
+ array_ops.zeros([1], dtype=dtypes.float32))
+ with variable_scope.variable_scope(
+ name_or_scope=('variables/shard_{}'.format(num_shards)
+ if num_shards else 'variables'),
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0)):
+ gender_weights = variable_scope.get_variable(
+ name='gender',
+ initializer=array_ops.zeros([2], dtype=dtypes.float32))
+ variables = dict(
+ sparse_features_weights=[age_weights, gender_weights],
+ dense_features_weights=[])
+ options = dict(
+ symmetric_l2_regularization=1,
+ symmetric_l1_regularization=0,
+ num_table_shards=num_shards,
+ loss_type='logistic_loss')
+
+ lr = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+ unregularized_loss = lr.unregularized_loss(examples)
+ loss = lr.regularized_loss(examples)
+ predictions = lr.predictions(examples)
+ self.assertAllClose(0.693147, unregularized_loss.eval())
+ self.assertAllClose(0.693147, loss.eval())
+ train_op = lr.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ lr.update_weights(train_op).run()
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.593014 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.593014, loss.eval(), atol=0.01)
+ predicted_labels = get_binary_predictions_for_logistic(predictions)
+ self.assertAllEqual([0, 1], predicted_labels.eval())
+ self.assertAllClose(
+ 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)
+
def testSparseRandom(self):
dim = 20
num_examples = 1000
@@ -463,7 +526,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=0,
symmetric_l1_regularization=0,
@@ -521,7 +584,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
with self._single_threaded_test_session():
# Only use examples 0 and 2
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -561,7 +624,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -598,7 +661,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(3, 1)
+ variables = make_variable_dict(3, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -639,7 +702,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
@@ -679,7 +742,7 @@ class SdcaWithLogisticLossTest(SdcaModelTest):
for num_shards in _SHARD_NUMBERS:
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
- variables = make_variable_dict(1, 1)
+ variables = make_variable_dict(1, 1, num_shards)
options = dict(
symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 14f59a3f64..b98adf862b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -400,14 +400,16 @@ class SdcaModel(object):
sparse_weights = []
sparse_indices = []
- # If we have partitioned variables, keep a few lists of Tensors around
- # that we need for the assign_add after the op call to
- # gen_sdca_ops.sdca_optimizer().
- num_partitions_by_var = []
- p_assignments_by_var = []
- gather_ids_by_var = []
- for w, i in zip(self._slots['unshrinked_sparse_features_weights'],
- sparse_feature_indices):
+ # If we have partitioned variables, keep a few dictionaries of Tensors
+ # around that we need for the assign_add after the op call to
+ # gen_sdca_ops.sdca_optimizer(). These are keyed because we may have a
+ # mix of partitioned and un-partitioned variables.
+ num_partitions_by_var = {}
+ p_assignments_by_var = {}
+ gather_ids_by_var = {}
+ for v_num, (w, i) in enumerate(
+ zip(self._slots['unshrinked_sparse_features_weights'],
+ sparse_feature_indices)):
# Append the sparse_indices (in full-variable space).
sparse_idx = math_ops.cast(
array_ops.unique(math_ops.cast(i, dtypes.int32))[0],
@@ -456,10 +458,10 @@ class SdcaModel(object):
gather_ids = data_flow_ops.dynamic_partition(new_ids,
p_assignments,
num_partitions)
- # Append these to the lists for use in the later update.
- num_partitions_by_var.append(num_partitions)
- p_assignments_by_var.append(p_assignments)
- gather_ids_by_var.append(gather_ids)
+ # Add these into the dictionaries for use in the later update.
+ num_partitions_by_var[v_num] = num_partitions
+ p_assignments_by_var[v_num] = p_assignments
+ gather_ids_by_var[v_num] = gather_ids
# Gather the weights from each partition.
partition_gathered_weights = []
diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md
index a676b705f1..a4b3d83efe 100644
--- a/tensorflow/contrib/lite/README.md
+++ b/tensorflow/contrib/lite/README.md
@@ -4,5 +4,5 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded
devices. It enables low-latency inference of on-device machine learning models
with a small binary size and fast performance supporting hardware acceleration.
-See the documentation: https://www.tensorflow.org/mobile/tflite/
-Documentation edits can be made here: [tensorflow/docs_src/mobile/tflite](../../docs_src/mobile/tflite)
+See the documentation: https://www.tensorflow.org/lite/
+Documentation edits can be made here: [tensorflow/contrib/lite/g3doc](./g3doc/)
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 9317e2bb6e..fc4d9b4f17 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -283,6 +283,7 @@ def generated_test_models():
"sparse_to_dense",
"split",
"sqrt",
+ "square",
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",
@@ -293,34 +294,73 @@ def generated_test_models():
#"transpose_conv", # disabled due to b/111213074
"unpack",
"where",
+ "zeros_like",
]
-def gen_zip_test(name, test_name, **kwargs):
+def generated_test_conversion_modes():
+ """Returns a list of conversion modes."""
+
+ # TODO(nupurgarg): Add "pb2lite" when it's in open source. b/113614050.
+ return ["toco-extended", ""]
+
+def generated_test_models_all():
+ """Generates a list of all tests with the different converters.
+
+ Returns:
+ List of tuples representing (conversion mode, name of test).
+ """
+ conversion_modes = generated_test_conversion_modes()
+ tests = generated_test_models()
+ options = []
+ for conversion_mode in conversion_modes:
+ for test in tests:
+ if conversion_mode:
+ test += "_%s" % conversion_mode
+ options.append((conversion_mode, test))
+ return options
+
+def gen_zip_test(name, test_name, conversion_mode, **kwargs):
"""Generate a zipped-example test and its dependent zip files.
Args:
- name: Resulting cc_test target name
- test_name: Test targets this model. Comes from the list above.
- **kwargs: tf_cc_test kwargs.
+ name: str. Resulting cc_test target name
+ test_name: str. Test targets this model. Comes from the list above.
+ conversion_mode: str. Which conversion mode to run with. Comes from the
+ list above.
+ **kwargs: tf_cc_test kwargs
"""
+ toco = "//tensorflow/contrib/lite/toco:toco"
+ flags = ""
+ if conversion_mode:
+ # TODO(nupurgarg): Comment in when pb2lite is in open source. b/113614050.
+ # if conversion_mode == "pb2lite":
+ # toco = "//tensorflow/contrib/lite/experimental/pb2lite:pb2lite"
+ flags = "--ignore_toco_errors --run_with_extended"
+ kwargs["tags"].append("skip_already_failing")
+ kwargs["tags"].append("no_oss")
+ kwargs["tags"].append("notap")
+
gen_zipped_test_file(
name = "zip_%s" % test_name,
file = "%s.zip" % test_name,
+ toco = toco,
+ flags = flags,
)
tf_cc_test(name, **kwargs)
-def gen_zipped_test_file(name, file):
+def gen_zipped_test_file(name, file, toco, flags):
"""Generate a zip file of tests by using :generate_examples.
Args:
- name: Name of output. We will produce "`file`.files" as a target.
- file: The name of one of the generated_examples targets, e.g. "transpose"
+ name: str. Name of output. We will produce "`file`.files" as a target.
+ file: str. The name of one of the generated_examples targets, e.g. "transpose"
+ toco: str. Pathname of toco binary to run
+ flags: str. Any additional flags to include
"""
- toco = "//tensorflow/contrib/lite/toco:toco"
native.genrule(
name = file + ".files",
- cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco +
- " --zip_to_output " + file + " $(@D)"),
+ cmd = (("$(locations :generate_examples) --toco $(locations {0}) " +
+ " --zip_to_output {1} {2} $(@D)").format(toco, file, flags)),
outs = [file],
tools = [
":generate_examples",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 9cf4bea73e..7809d114e2 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -117,6 +117,9 @@ typedef enum {
kTfLiteBuiltinReduceMin = 89,
kTfLiteBuiltinFloorDiv = 90,
kTfLiteBuiltinReduceAny = 91,
+ kTfLiteBuiltinSquare = 92,
+ kTfLiteBuiltinZerosLike = 93,
+ kTfLiteBuiltinFill = 94,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index fa43e6a024..be9d551ee4 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -25,6 +25,9 @@ extern "C" {
// TODO(aselle): Consider using "if this then that" for testing.
+// IMPORTANT: All new members of structs must be added at the end to ensure
+// backwards compatibility.
+
// Possible padding types (for convolutions)
typedef enum {
kTfLitePaddingUnknown = 0,
@@ -71,11 +74,15 @@ typedef struct {
} TfLitePoolParams;
typedef struct {
+ // Parameters for DepthwiseConv version 1 or above.
TfLitePadding padding;
int stride_width;
int stride_height;
int depth_multiplier;
TfLiteFusedActivation activation;
+ // Parameters for DepthwiseConv version 2 or above.
+ int dilation_width_factor;
+ int dilation_height_factor;
} TfLiteDepthwiseConvParams;
typedef struct {
diff --git a/tensorflow/contrib/lite/c/c_api_internal.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 1846bad4b7..8a0c177b19 100644
--- a/tensorflow/contrib/lite/c/c_api_internal.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -14,15 +14,29 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#ifndef TF_LITE_STATIC_MEMORY
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
+#endif // TF_LITE_STATIC_MEMORY
int TfLiteIntArrayGetSizeInBytes(int size) {
static TfLiteIntArray dummy;
return sizeof(dummy) + sizeof(dummy.data[0]) * size;
}
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
+ if (a == b) return 1;
+ if (a == NULL || b == NULL) return 0;
+ if (a->size != b->size) return 0;
+ int i = 0;
+ for (; i < a->size; i++)
+ if (a->data[i] != b->data[i]) return 0;
+ return 1;
+}
+
+#ifndef TF_LITE_STATIC_MEMORY
+
TfLiteIntArray* TfLiteIntArrayCreate(int size) {
TfLiteIntArray* ret =
(TfLiteIntArray*)malloc(TfLiteIntArrayGetSizeInBytes(size));
@@ -40,16 +54,6 @@ void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) {
printf("]\n");
}
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) {
- if (a == b) return 1;
- if (a == NULL || b == NULL) return 0;
- if (a->size != b->size) return 0;
- int i = 0;
- for (; i < a->size; i++)
- if (a->data[i] != b->data[i]) return 0;
- return 1;
-}
-
TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) {
if (!src) return NULL;
TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size);
@@ -102,3 +106,4 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
}
tensor->bytes = num_bytes;
}
+#endif // TF_LITE_STATIC_MEMORY
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
index 48df68a654..ee3dff6792 100644
--- a/tensorflow/contrib/lite/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -146,7 +146,7 @@ void TfLiteIntArrayFree(TfLiteIntArray* v);
#define TF_LITE_ENSURE_OK(context, status) \
do { \
if ((status) != kTfLiteOk) { \
- return status; \
+ return kTfLiteError; \
} \
} while (0)
@@ -374,6 +374,11 @@ typedef struct TfLiteContext {
// WARNING: This is an experimental interface that is subject to change.
void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
TfLiteExternalContext*);
+
+ // Flag for allowing float16 precision for FP32 calculation.
+ // default: false.
+ // WARNING: This is an experimental API and subject to change.
+ bool allow_fp32_relax_to_fp16;
} TfLiteContext;
typedef struct _TfLiteRegistration {
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index 1420fbcdc6..e6900e0950 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -44,16 +44,6 @@ void FlatBufferIntVectorToArray(int max_size_of_buffer,
}
}
-// Allocate a structure using malloc, but make sure the structure is a POD
-// structure that doesn't require constructors to run. The reason we do this,
-// is that Interpreter's C extension part will take ownership so destructors
-// will not be run during deallocation.
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
} // namespace
TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
@@ -98,7 +88,8 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
// need to be released by calling `free`.`
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data) {
auto parse_padding = [](Padding padding) {
switch (padding) {
case Padding_SAME:
@@ -150,7 +141,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = nullptr;
switch (op_type) {
case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ TfLiteConvParams* params = allocator->AllocatePOD<TfLiteConvParams>();
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -165,7 +156,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ TfLiteCastParams* params = allocator->AllocatePOD<TfLiteCastParams>();
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
auto in_status =
ConvertTensorType(schema_params->in_data_type(),
@@ -174,7 +165,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
ConvertTensorType(schema_params->out_data_type(),
&params->out_data_type, error_reporter);
if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
+ allocator->Deallocate(params);
return kTfLiteError;
}
}
@@ -183,7 +174,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_LSH_PROJECTION: {
TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
+ allocator->AllocatePOD<TfLiteLSHProjectionParams>();
if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
params->type = parseLSHProjectionType(lshParams->type());
}
@@ -193,7 +184,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ TfLitePoolParams* params = allocator->AllocatePOD<TfLitePoolParams>();
if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = parse_padding(pool_params->padding());
params->stride_width = pool_params->stride_w();
@@ -208,7 +199,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_DEPTHWISE_CONV_2D: {
TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
+ allocator->AllocatePOD<TfLiteDepthwiseConvParams>();
if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
params->padding = parse_padding(conv_params->padding());
params->stride_width = conv_params->stride_w();
@@ -216,12 +207,15 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->depth_multiplier = conv_params->depth_multiplier();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
}
*builtin_data = reinterpret_cast<void*>(params);
break;
}
case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ TfLiteSVDFParams* params = allocator->AllocatePOD<TfLiteSVDFParams>();
if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
params->rank = svdf_params->rank();
params->activation =
@@ -232,7 +226,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ TfLiteSequenceRNNParams* params =
+ allocator->AllocatePOD<TfLiteSequenceRNNParams>();
if (auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
@@ -243,7 +238,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
params->activation =
parse_activation(rnn_params->fused_activation_function());
@@ -253,7 +248,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ allocator->AllocatePOD<TfLiteEmbeddingLookupSparseParams>();
if (auto* embedding_params =
op->builtin_options_as_EmbeddingLookupSparseOptions()) {
params->combiner = parseCombinerType(embedding_params->combiner());
@@ -263,7 +258,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_FULLY_CONNECTED: {
TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
+ allocator->AllocatePOD<TfLiteFullyConnectedParams>();
if (auto* fully_connected_params =
op->builtin_options_as_FullyConnectedOptions()) {
params->activation = parse_activation(
@@ -288,7 +283,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
// no-op.
break;
case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ TfLiteSoftmaxParams* params =
+ allocator->AllocatePOD<TfLiteSoftmaxParams>();
if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
params->beta = softmax_params->beta();
}
@@ -297,7 +293,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_CONCATENATION: {
TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
+ allocator->AllocatePOD<TfLiteConcatenationParams>();
if (auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation =
@@ -308,7 +304,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
+ auto* params = allocator->AllocatePOD<TfLiteMulParams>();
if (auto* schema_params = op->builtin_options_as_MulOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -317,7 +313,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
+ auto* params = allocator->AllocatePOD<TfLiteAddParams>();
if (auto* schema_params = op->builtin_options_as_AddOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -326,7 +322,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
+ auto* params = allocator->AllocatePOD<TfLiteDivParams>();
if (auto* schema_params = op->builtin_options_as_DivOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -335,7 +331,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSubParams>();
if (auto* schema_params = op->builtin_options_as_SubOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -344,7 +340,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteL2NormParams>();
if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
params->activation =
parse_activation(schema_params->fused_activation_function());
@@ -353,7 +349,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ auto* params = allocator->AllocatePOD<TfLiteLocalResponseNormParams>();
if (auto* schema_params =
op->builtin_options_as_LocalResponseNormalizationOptions()) {
params->radius = schema_params->radius();
@@ -367,7 +363,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
@@ -386,7 +382,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
params->align_corners = schema_params->align_corners();
@@ -395,7 +391,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReshapeParams>();
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
auto* new_shape = schema_params->new_shape();
FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
@@ -406,7 +402,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ TfLiteSkipGramParams* params =
+ allocator->AllocatePOD<TfLiteSkipGramParams>();
if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
params->ngram_size = skip_gram_params->ngram_size();
params->max_skip_size = skip_gram_params->max_skip_size();
@@ -416,7 +413,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSpaceToDepthParams>();
if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
params->block_size = schema_params->block_size();
}
@@ -424,7 +421,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ TfLiteGatherParams* params = allocator->AllocatePOD<TfLiteGatherParams>();
params->axis = 0;
if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
params->axis = gather_params->axis();
@@ -439,7 +436,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_REDUCE_PROD:
case BuiltinOperator_REDUCE_ANY:
case BuiltinOperator_SUM: {
- auto* params = MallocPOD<TfLiteReducerParams>();
+ auto* params = allocator->AllocatePOD<TfLiteReducerParams>();
if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
params->keep_dims = schema_params->keep_dims();
}
@@ -447,7 +444,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSplitParams>();
if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
params->num_splits = schema_params->num_splits();
}
@@ -455,7 +452,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteSqueezeParams>();
if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
const auto& squeeze_dims = schema_params->squeeze_dims();
FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
@@ -466,7 +463,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ auto* params = allocator->AllocatePOD<TfLiteStridedSliceParams>();
if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
params->begin_mask = schema_params->begin_mask();
params->end_mask = schema_params->end_mask();
@@ -478,7 +475,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMaxParams>();
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -487,7 +484,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
+ auto* params = allocator->AllocatePOD<TfLiteArgMinParams>();
if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
ConvertTensorType(schema_params->output_type(), &params->output_type,
error_reporter);
@@ -497,7 +494,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_TRANSPOSE_CONV: {
TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
+ allocator->AllocatePOD<TfLiteTransposeConvParams>();
if (auto* transpose_conv_params =
op->builtin_options_as_TransposeConvOptions()) {
params->padding = parse_padding(transpose_conv_params->padding());
@@ -509,7 +506,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
}
case BuiltinOperator_SPARSE_TO_DENSE: {
TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
+ allocator->AllocatePOD<TfLiteSparseToDenseParams>();
if (auto* sparse_to_dense_params =
op->builtin_options_as_SparseToDenseOptions()) {
params->validate_indices = sparse_to_dense_params->validate_indices();
@@ -518,7 +515,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
+ auto* params = allocator->AllocatePOD<TfLiteShapeParams>();
if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
ConvertTensorType(schema_params->out_type(), &params->out_type,
error_reporter);
@@ -527,7 +524,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ TfLitePackParams* params = allocator->AllocatePOD<TfLitePackParams>();
if (auto* pack_params = op->builtin_options_as_PackOptions()) {
params->values_count = pack_params->values_count();
params->axis = pack_params->axis();
@@ -541,7 +538,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return kTfLiteError;
}
case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ auto* params = allocator->AllocatePOD<TfLiteFakeQuantParams>();
if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
params->min = schema_params->min();
params->max = schema_params->max();
@@ -552,7 +549,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
+ auto* params = allocator->AllocatePOD<TfLiteOneHotParams>();
if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
params->axis = schema_params->axis();
}
@@ -560,7 +557,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
break;
}
case BuiltinOperator_UNPACK: {
- TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ TfLiteUnpackParams* params = allocator->AllocatePOD<TfLiteUnpackParams>();
if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
params->num = unpack_params->num();
params->axis = unpack_params->axis();
@@ -614,6 +611,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_LOGICAL_AND:
case BuiltinOperator_LOGICAL_NOT:
case BuiltinOperator_FLOOR_DIV:
+ case BuiltinOperator_SQUARE:
+ case BuiltinOperator_ZEROS_LIKE:
+ case BuiltinOperator_FILL:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
index 4dec6f9cfc..c770e627fd 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -26,6 +26,25 @@ limitations under the License.
namespace tflite {
+// Interface class for builtin data allocations.
+class BuiltinDataAllocator {
+ public:
+ virtual void* Allocate(size_t size) = 0;
+ virtual void Deallocate(void* data) = 0;
+
+ // Allocate a structure, but make sure it is a POD structure that doesn't
+ // require constructors to run. The reason we do this, is that Interpreter's C
+ // extension part will take ownership so destructors will not be run during
+ // deallocation.
+ template <typename T>
+ T* AllocatePOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(this->Allocate(sizeof(T)));
+ }
+
+ virtual ~BuiltinDataAllocator() {}
+};
+
// Parse the appropriate data out of the op.
//
// This handles builtin data explicitly as there are flatbuffer schemas.
@@ -36,7 +55,8 @@ namespace tflite {
// function's responsibility to free it.
// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data);
+ ErrorReporter* error_reporter,
+ BuiltinDataAllocator* allocator, void** builtin_data);
// Converts the tensor data type used in the flat buffer to the representation
// used by the runtime.
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
index b12bdf43b2..8ae94e1d33 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -39,11 +39,31 @@ class MockErrorReporter : public ErrorReporter {
int buffer_size_;
};
+// Used to determine how the op data parsing function creates its working space.
+class MockDataAllocator : public BuiltinDataAllocator {
+ public:
+ MockDataAllocator() : is_allocated_(false) {}
+ void* Allocate(size_t size) override {
+ EXPECT_FALSE(is_allocated_);
+ const int max_size = kBufferSize;
+ EXPECT_LE(size, max_size);
+ is_allocated_ = true;
+ return buffer_;
+ }
+ void Deallocate(void* data) override { is_allocated_ = false; }
+
+ private:
+ static constexpr int kBufferSize = 1024;
+ char buffer_[kBufferSize];
+ bool is_allocated_;
+};
+
} // namespace
TEST(FlatbufferConversions, TestParseOpDataConv) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> conv_options =
@@ -58,7 +78,7 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_NE(nullptr, output_data);
TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
EXPECT_EQ(kTfLitePaddingSame, params->padding);
@@ -67,12 +87,12 @@ TEST(FlatbufferConversions, TestParseOpDataConv) {
EXPECT_EQ(kTfLiteActRelu, params->activation);
EXPECT_EQ(3, params->dilation_width_factor);
EXPECT_EQ(4, params->dilation_height_factor);
- free(output_data);
}
TEST(FlatbufferConversions, TestParseOpDataCustom) {
MockErrorReporter mock_reporter;
ErrorReporter* reporter = &mock_reporter;
+ MockDataAllocator mock_allocator;
flatbuffers::FlatBufferBuilder builder;
flatbuffers::Offset<void> null_options;
@@ -84,7 +104,7 @@ TEST(FlatbufferConversions, TestParseOpDataCustom) {
const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
void* output_data = nullptr;
EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
- &output_data));
+ &mock_allocator, &output_data));
EXPECT_EQ(nullptr, output_data);
}
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index 984f8bbc98..43ec5d53b8 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -157,6 +157,34 @@ TEST_F(DelegateTest, OnlyTFLite) {
ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
}
+TEST_F(DelegateTest, MultipleInvokeCalls) {
+ // Call Invoke() multiple times on the same model.
+ AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3});
+ AddTfLiteMulOp({0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2, 1});
+ SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f});
+ SetShape(1, {2, 2, 1});
+ SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+ ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f));
+
+ SetShape(0, {2, 2, 1});
+ SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f});
+ SetShape(1, {2, 2, 1});
+ SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1));
+ ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f));
+}
+
TEST_F(DelegateTest, MultipleInterpretersSameDelegate) {
// Build a graph, configure the delegate and set inputs.
{
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 274c3c082a..48a2f56baf 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index 8584999ace..d47be761fb 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
#include "absl/memory/memory.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index e3eebac4da..d85e576284 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -518,7 +518,7 @@ class NNAPIDelegateKernel {
}
break;
case kTfLiteBuiltinReshape:
- if (version == 1) {
+ if (version == 1 && node->inputs->size == 2) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
return ANEURALNETWORKS_RESHAPE;
@@ -1115,6 +1115,14 @@ class NNAPIDelegateKernel {
CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_.get(), inputs.size(), inputs.data(),
outputs.size(), outputs.data()));
+
+ // Set relaxed computation mode for fp32 if possible.
+ if (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) {
+ CHECK_NN(context,
+ ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ nn_model_.get(), context->allow_fp32_relax_to_fp16));
+ }
+
// Finalize the model
CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get()));
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 4b01aefd6a..9626c54c74 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -40,13 +40,15 @@ class FloatAddOpModel : public SingleOpModelWithNNAPI {
public:
FloatAddOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
- ActivationFunctionType activation_type) {
+ ActivationFunctionType activation_type,
+ bool allow_fp32_relax_to_fp16 = false) {
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
CreateAddOptions(builder_, activation_type).Union());
- BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)},
+ allow_fp32_relax_to_fp16);
}
int input1() { return input1_; }
@@ -71,6 +73,19 @@ TEST(NNAPIDelegate, AddWithNoActivation) {
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
}
+// Do a test with the NN API using no activation.
+// The test allows computing FP32 with FP16 precision. In this particular case,
+// calculating in FP32 or FP16 should produce the same results.
+TEST(NNAPIDelegate, AddWithNoActivationRelaxed) {
+ FloatAddOpModel m(
+ {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE, true);
+ m.PopulateTensor<float>(m.input1(), {-2.0, -1.0, 1.0, 2.0});
+ m.PopulateTensor<float>(m.input2(), {1.0, 2.0, 3.0, 4.0});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 1.0, 4.0, 6.0}));
+}
+
// Do a test with the NN api with relu.
TEST(NNAPIDelegate, AddWithRelu) {
FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
diff --git a/tensorflow/contrib/lite/examples/android/app/README.md b/tensorflow/contrib/lite/examples/android/app/README.md
index cbdeeac879..7347147f99 100644
--- a/tensorflow/contrib/lite/examples/android/app/README.md
+++ b/tensorflow/contrib/lite/examples/android/app/README.md
@@ -1,8 +1,43 @@
# TF Lite Android App Example
+A simple Android example that demonstrates image classification and object
+detection using the camera, as well as speech recognition using the microphone.
+
+## Building in Android Studio with TensorFlow Lite AAR from JCenter.
+The build.gradle is configured to use TensorFlow Lite's nightly build.
+
+If you see a build error related to compatibility with Tensorflow Lite's Java
+API (example: method X is undefined for type Interpreter), there has likely been
+a backwards compatible change to the API. You will need to pull new app code
+that's compatible with the nightly build and may need to first wait a few days
+for our external and internal code to merge.
+
## Building from Source with Bazel
-1. Install [Bazel](https://docs.bazel.build/versions/master/install.html), the Android NDK and SDK. The recommended versions are specified on this [webpage](https://www.tensorflow.org/mobile/tflite/demo_android#build_tensorflow_lite_and_the_demo_app_from_source).
+1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel):
+
+ 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites).
+ It's easiest with Android Studio.
+
+ - You'll need at least SDK version 23.
+ - Make sure to install the latest version of Bazel. Some distributions
+ ship with Bazel 0.5.4, which is too old.
+ - Bazel requires Android Build Tools `26.0.1` or higher.
+ - You also need to install the Android Support Repository, available
+ through Android Studio under `Android SDK Manager -> SDK Tools ->
+ Android Support Repository`.
+
+ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace)
+ to add SDK and NDK targets.
+
+ NOTE: As long as you have the SDK and NDK installed, the `./configure`
+ script will create these rules for you. Answer "Yes" when the script asks
+ to automatically configure the `./WORKSPACE`.
+
+ - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
+ you have installed.
+ - By default, Android Studio will install the SDK to `~/Android/Sdk` and
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build this demo app with Bazel. The demo needs C++11. We configure the fat_apk_cpu flag to package support for 4 hardware variants. You may replace it with --config=android_arm64 on a 64-bit device and --config=android_arm for 32-bit device:
diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
index 6fdcf78b69..21ad39a6bf 100644
--- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
+++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h
@@ -80,8 +80,7 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
interpreter->Invoke();
auto output = interpreter->typed_tensor<float>(2);
- auto output_number_of_pixels =
- wanted_height * wanted_height * wanted_channels;
+ auto output_number_of_pixels = wanted_height * wanted_width * wanted_channels;
for (int i = 0; i < output_number_of_pixels; i++) {
if (s->input_floating)
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index ea4a543252..52e71619de 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -1,5 +1,12 @@
package(default_visibility = ["//visibility:private"])
+package_group(
+ name = "experimental",
+ packages = [
+ "//tensorflow/contrib/lite/experimental/...",
+ ],
+)
+
licenses(["notice"]) # Apache 2.0
load(
@@ -51,6 +58,9 @@ cc_library(
srcs = ["c_api.cc"],
hdrs = ["c_api.h"],
copts = tflite_copts(),
+ visibility = [
+ ":experimental",
+ ],
deps = [
":c_api_internal",
"//tensorflow/contrib/lite:context",
@@ -68,6 +78,7 @@ cc_library(
deps = [
":c_api",
":c_api_internal",
+ "//tensorflow/contrib/lite:kernel_api",
],
)
@@ -93,6 +104,7 @@ cc_test(
deps = [
":c_api",
":c_api_experimental",
+ "//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index c589cf71ea..9c29f9d8b9 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
@@ -26,6 +27,26 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
+namespace {
+class CallbackErrorReporter : public tflite::ErrorReporter {
+ public:
+ using ErrorCallback = void (*)(void* user_data, const char* format,
+ va_list args);
+
+ CallbackErrorReporter(ErrorCallback callback, void* user_data)
+ : callback_(callback), user_data_(user_data) {}
+
+ int Report(const char* format, va_list args) override {
+ callback_(user_data_, format, args);
+ return 0;
+ }
+
+ private:
+ ErrorCallback callback_;
+ void* user_data_;
+};
+} // namespace
+
// LINT.IfChange
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
@@ -56,14 +77,38 @@ void TFL_InterpreterOptionsSetNumThreads(TFL_InterpreterOptions* options,
options->num_threads = num_threads;
}
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+ TFL_InterpreterOptions* options,
+ void (*reporter)(void* user_data, const char* format, va_list args),
+ void* user_data) {
+ options->error_reporter = reporter;
+ options->error_reporter_user_data = user_data;
+}
+
TFL_Interpreter* TFL_NewInterpreter(
const TFL_Model* model, const TFL_InterpreterOptions* optional_options) {
if (!model || !model->impl) {
return nullptr;
}
+ std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+ if (optional_options && optional_options->error_reporter != nullptr) {
+ optional_error_reporter.reset(
+ new CallbackErrorReporter(optional_options->error_reporter,
+ optional_options->error_reporter_user_data));
+ }
+
+ // TODO(b/111881878): Allow use of C API without pulling in all builtin ops.
tflite::ops::builtin::BuiltinOpResolver resolver;
- tflite::InterpreterBuilder builder(*model->impl, resolver);
+ if (optional_options) {
+ resolver.AddAll(optional_options->op_resolver);
+ }
+ tflite::ErrorReporter* error_reporter = optional_error_reporter
+ ? optional_error_reporter.get()
+ : tflite::DefaultErrorReporter();
+ tflite::InterpreterBuilder builder(model->impl->GetModel(), resolver,
+ error_reporter);
+
std::unique_ptr<tflite::Interpreter> interpreter;
if (builder(&interpreter) != kTfLiteOk) {
return nullptr;
@@ -76,7 +121,8 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{model->impl, std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(optional_error_reporter),
+ std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index b429e76870..f52ab8f9ed 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_H_
+#include <stdarg.h>
#include <stdint.h>
// Eventually the various C APIs defined in context.h will be migrated into
@@ -52,8 +53,9 @@ limitations under the License.
extern "C" {
#endif // __cplusplus
-typedef TfLiteTensor TFL_Tensor;
+typedef TfLiteRegistration TFL_Registration;
typedef TfLiteStatus TFL_Status;
+typedef TfLiteTensor TFL_Tensor;
typedef TfLiteType TFL_Type;
// --------------------------------------------------------------------------
@@ -85,6 +87,17 @@ TFL_CAPI_EXPORT extern void TFL_DeleteInterpreterOptions(
TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetNumThreads(
TFL_InterpreterOptions* options, int32_t num_threads);
+// Sets a custom error reporter for interpreter execution.
+//
+// * `reporter` takes the provided `user_data` object, as well as a C-style
+// format string and arg list (see also vprintf).
+// * `user_data` is optional. If provided, it is owned by the client and must
+// remain valid for the duration of the interpreter lifetime.
+TFL_CAPI_EXPORT extern void TFL_InterpreterOptionsSetErrorReporter(
+ TFL_InterpreterOptions* options,
+ void (*reporter)(void* user_data, const char* format, va_list args),
+ void* user_data);
+
// --------------------------------------------------------------------------
// TFL_Interpreter provides inference from a provided model.
typedef struct TFL_Interpreter TFL_Interpreter;
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
index c4dbc55cbf..0f16595811 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.cc
@@ -26,6 +26,22 @@ TFL_Status TFL_InterpreterResetVariableTensorsToZero(
return interpreter->impl->ResetVariableTensorsToZero();
}
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int32_t min_version,
+ int32_t max_version) {
+ options->op_resolver.AddBuiltin(static_cast<tflite::BuiltinOperator>(op),
+ registration, min_version, max_version);
+}
+
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version) {
+ options->op_resolver.AddCustom(name, registration, min_version, max_version);
+}
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
index b0ac258dcf..b8de7b9964 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental.h
@@ -15,16 +15,41 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_C_C_API_EXPERIMENTAL_H_
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
+typedef TfLiteBuiltinOperator TFL_BuiltinOperator;
+
// Resets all variable tensors to zero.
TFL_CAPI_EXPORT extern TFL_Status TFL_InterpreterResetVariableTensorsToZero(
TFL_Interpreter* interpreter);
+// Adds an op registration for a builtin operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddBuiltinOp(TFL_InterpreterOptions* options,
+ TFL_BuiltinOperator op,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
+// Adds an op registration for a custom operator.
+//
+// NOTE: The interpreter will make a copy of `registration` internally, so the
+// caller should ensure that its contents (function pointers, etc...) remain
+// valid for the duration of the interpreter's lifetime. A common practice is
+// making the provided TFL_Registration instance static.
+void TFL_InterpreterOptionsAddCustomOp(TFL_InterpreterOptions* options,
+ const char* name,
+ const TFL_Registration* registration,
+ int min_version, int max_version);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
index db6e5251de..d86ad00d6d 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_experimental_test.cc
@@ -16,25 +16,40 @@ limitations under the License.
#include "tensorflow/contrib/lite/experimental/c/c_api_experimental.h"
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/builtin_ops.h"
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace {
+TfLiteRegistration* GetDummyRegistration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = [](TfLiteContext*, TfLiteNode*) { return kTfLiteOk; },
+ };
+ return &registration;
+}
+
TEST(CApiExperimentalSimple, Smoke) {
TFL_Model* model = TFL_NewModelFromFile(
"tensorflow/contrib/lite/testdata/add.bin");
ASSERT_NE(model, nullptr);
- TFL_Interpreter* interpreter =
- TFL_NewInterpreter(model, /*optional_options=*/nullptr);
+ TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+ TFL_InterpreterOptionsAddBuiltinOp(options, kTfLiteBuiltinAdd,
+ GetDummyRegistration(), 1, 1);
+
+ TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
ASSERT_NE(interpreter, nullptr);
ASSERT_EQ(TFL_InterpreterAllocateTensors(interpreter), kTfLiteOk);
-
EXPECT_EQ(TFL_InterpreterResetVariableTensorsToZero(interpreter), kTfLiteOk);
+ EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteOk);
- TFL_DeleteModel(model);
TFL_DeleteInterpreter(interpreter);
+ TFL_DeleteInterpreterOptions(options);
+ TFL_DeleteModel(model);
}
} // namespace
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index 60c2e4e2cd..da3af3cad4 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -19,9 +19,13 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
// Internal structures used by the C API. These are likely to change and should
// not be depended on.
+//
+// NOTE: This header does not follow C conventions and does not define a C API.
+// It is effectively an (internal) implementation detail of the C API.
struct TFL_Model {
// Sharing is safe as FlatBufferModel is const.
@@ -33,12 +37,24 @@ struct TFL_InterpreterOptions {
kDefaultNumThreads = -1,
};
int num_threads = kDefaultNumThreads;
+
+ tflite::MutableOpResolver op_resolver;
+
+ void (*error_reporter)(void* user_data, const char* format,
+ va_list args) = nullptr;
+ void* error_reporter_user_data = nullptr;
};
struct TFL_Interpreter {
// Taking a reference to the (const) model data avoids lifetime-related issues
// and complexity with the TFL_Model's existence.
std::shared_ptr<const tflite::FlatBufferModel> model;
+
+ // The interpreter does not take ownership of the provided ErrorReporter
+ // instance, so we ensure its validity here. Note that the interpreter may use
+ // the reporter in its destructor, so it should be declared first.
+ std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
+
std::unique_ptr<tflite::Interpreter> impl;
};
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index 649dac8d1a..48a3714ec3 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -85,6 +85,37 @@ TEST(CApiSimple, Smoke) {
TFL_DeleteInterpreter(interpreter);
}
+TEST(CApiSimple, ErrorReporter) {
+ TFL_Model* model = TFL_NewModelFromFile(
+ "tensorflow/contrib/lite/testdata/add.bin");
+ TFL_InterpreterOptions* options = TFL_NewInterpreterOptions();
+
+ // Install a custom error reporter into the interpreter by way of options.
+ tflite::TestErrorReporter reporter;
+ TFL_InterpreterOptionsSetErrorReporter(
+ options,
+ [](void* user_data, const char* format, va_list args) {
+ reinterpret_cast<tflite::TestErrorReporter*>(user_data)->Report(format,
+ args);
+ },
+ &reporter);
+ TFL_Interpreter* interpreter = TFL_NewInterpreter(model, options);
+
+ // The options/model can be deleted immediately after interpreter creation.
+ TFL_DeleteInterpreterOptions(options);
+ TFL_DeleteModel(model);
+
+ // Invoke the interpreter before tensor allocation.
+ EXPECT_EQ(TFL_InterpreterInvoke(interpreter), kTfLiteError);
+
+ // The error should propagate to the custom error reporter.
+ EXPECT_EQ(reporter.error_messages(),
+ "Invoke called on model that is not ready.");
+ EXPECT_EQ(reporter.num_calls(), 1);
+
+ TFL_DeleteInterpreter(interpreter);
+}
+
} // namespace
int main(int argc, char** argv) {
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 8442c4d46c..b1ebe4a804 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index aa42b495bd..942dbbbeae 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
index e6d5a776b3..b35c6e0655 100644
--- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <iostream>
#include <unordered_map>
#include <unordered_set>
-#include "flatbuffers/minireflect.h" // flatbuffers
+#include "flatbuffers/minireflect.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
index 52b17faf82..555a9cc4b0 100644
--- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -117,6 +117,8 @@ Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
FlatBufferBuilder* fbb) {
+ // Initialized to -1.
+ // A value of -1 means this tensor will not be exported.
tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
std::vector<Offset<Tensor>> tensors;
@@ -135,15 +137,17 @@ Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
int curr_output_index = 0;
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
tensor_index++) {
- if (!tensor_is_temporary[tensor_index]) {
+ // Temporary tensors and unused tensors will not be written.
+ if (!tensor_is_temporary[tensor_index] &&
+ unused_tensors_.find(tensor_index) == unused_tensors_.end()) {
tensor_to_written_tensor_[tensor_index] = curr_output_index++;
}
}
for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
++tensor_index) {
- // Skip temporaries.
- if (tensor_is_temporary[tensor_index]) continue;
+ // Tensor not exported.
+ if (tensor_to_written_tensor_[tensor_index] == -1) continue;
if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
// We only need to convert non temporaries
@@ -215,7 +219,9 @@ std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
std::vector<int> output;
output.reserve(input.size());
for (int x : input) {
- output.push_back(tensor_to_written_tensor_[x]);
+ if (tensor_to_written_tensor_[x] != -1) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
}
return output;
}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
index a98108b496..a5f14697cf 100644
--- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -62,6 +62,10 @@ class InterpreterWriter {
// caller to change the custom data.
TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
CustomWriter custom_writer);
+ // Tensors that are unused and shouldn't be written.
+ void SetUnusedTensors(const std::set<int>& unused_tensors) {
+ unused_tensors_ = unused_tensors;
+ }
private:
template <class T>
@@ -111,8 +115,9 @@ class InterpreterWriter {
int builtin;
std::string custom;
};
+ std::set<int> unused_tensors_;
// For every tensor index in the interpreter, the index in the written.
- // This is different due to temporary tensors not being written.
+ // This is different due to temporary and unused tensors not being written.
std::vector<int> tensor_to_written_tensor_;
// List of used opcodes
std::vector<OpCode> opcodes_;
diff --git a/tensorflow/contrib/lite/g3doc/_book.yaml b/tensorflow/contrib/lite/g3doc/_book.yaml
index 1dffe30790..de6914e536 100644
--- a/tensorflow/contrib/lite/g3doc/_book.yaml
+++ b/tensorflow/contrib/lite/g3doc/_book.yaml
@@ -5,7 +5,7 @@ upper_tabs:
# Dropdown menu
- name: Ecosystem
path: /ecosystem
- is_default: True
+ is_default: true
menu:
- include: /ecosystem/_menu_toc.yaml
lower_tabs:
@@ -14,46 +14,50 @@ upper_tabs:
- name: Guide
contents:
- title: Overview
- path: /mobile/overview
- - title: Developer Guide
- path: /mobile/devguide
- - title: Android Demo App
- path: /mobile/demo_android
- - title: iOS Demo App
- path: /mobile/demo_ios
+ path: /lite/overview
+ - title: Developer guide
+ path: /lite/devguide
+ - title: Android demo app
+ path: /lite/demo_android
+ - title: iOS demo app
+ path: /lite/demo_ios
- title: Performance
- path: /mobile/performance
- - break: True
+ path: /lite/performance
+ - break: true
- title: TensorFlow Lite APIs
- path: /mobile/apis
+ path: /lite/apis
- title: Custom operators
- path: /mobile/custom_operators
- - title: TensorFlow Lite Ops Versioning
- path: /mobile/ops_versioning
- - title: TensorFlow Lite Compatibility Guide
- path: /mobile/tf_ops_compatibility
- - title: List of Hosted Models
- path: /mobile/models
+ path: /lite/custom_operators
+ - title: TensorFlow Lite ops versioning
+ path: /lite/ops_versioning
+ - title: TensorFlow Lite compatibility guide
+ path: /lite/tf_ops_compatibility
+ - title: List of hosted models
+ path: /lite/models
- title: TensorFlow Lite for iOS
- path: /mobile/ios
+ path: /lite/ios
- title: TensorFlow Lite for Raspberry Pi
- path: /mobile/rpi
+ path: /lite/rpi
- - heading: TF Mobile
+ - title: TF Mobile
+ style: accordion
status: deprecated
- - title: Overview
- path: /mobile/tfmobile/
- - title: Building TensorFlow on Android
- path: /mobile/tfmobile/android_build
- - title: Building TensorFlow on IOS
- path: /mobile/tfmobile/ios_build
- - title: Integrating TensorFlow libraries
- path: /mobile/tfmobile/linking_libs
- - title: Preparing models for mobile deployment
- path: /mobile/tfmobile/prepare_models
- - title: Optimizing for mobile
- path: /mobile/tfmobile/optimizing
+ section:
+ - title: Overview
+ path: /lite/tfmobile/
+ - title: Building TensorFlow on Android
+ path: /lite/tfmobile/android_build
+ - title: Building TensorFlow on IOS
+ path: /lite/tfmobile/ios_build
+ - title: Integrating TensorFlow libraries
+ path: /lite/tfmobile/linking_libs
+ - title: Preparing models for mobile deployment
+ path: /lite/tfmobile/prepare_models
+ - title: Optimizing for mobile
+ path: /lite/tfmobile/optimizing
- name: API
+ skip_translation: true
contents:
- - include: /mobile/api_docs/python/_toc.yaml
+ - title: API
+ path: /api_docs/python/tf/contrib/lite
diff --git a/tensorflow/contrib/lite/g3doc/_index.yaml b/tensorflow/contrib/lite/g3doc/_index.yaml
index 9119e49117..bc66cc5dc1 100644
--- a/tensorflow/contrib/lite/g3doc/_index.yaml
+++ b/tensorflow/contrib/lite/g3doc/_index.yaml
@@ -1,59 +1,209 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
+project_path: /lite/_project.yaml
+book_path: /lite/_book.yaml
description: <!--no description-->
landing_page:
+ custom_css_path: /site-assets/css/style.css
rows:
- - heading: TensorFlow Lite is a lightweight solution for mobile and embedded devices.
+ - heading: TensorFlow Lite is for mobile and embedded devices.
+ description: >
+ <p style="max-width: 75%;">
+ TensorFlow Lite is the official solution for running machine learning
+ models on mobile and embedded devices. It enables on&#8209;device machine
+ learning inference with low latency and a small binary size on Android,
+ iOS, and other operating systems.
+ </p>
+ <style>
+ .tfo-landing-row-heading {
+ padding-top: 0 !important;
+ }
+ .tfo-landing-row-heading h2 {
+ margin-top: 0 !important;
+ }
+ .tfo-landing-row-heading-list ol, .tfo-landing-row-heading-list ul {
+ margin-top: 0;
+ }
+ </style>
+
+ - classname: tfo-landing-row-heading tfo-landing-row-heading-list
+ heading: Many benefits
+ description: >
+ On-device ML inference is difficult because of the many constraints—TensorFlow Lite can solve these:
items:
- - description: >
- TensorFlow Lite is TensorFlow’s lightweight solution for mobile and
- embedded devices. It enables on-device machine learning inference with
- low latency and a small binary size. TensorFlow Lite also supports
- hardware acceleration with the
- <a href='https://developer.android.com/ndk/guides/neuralnetworks/index.html'>Android Neural Networks API</a>.
- list:
- - heading: Key point 1
+ - list:
+ - heading: Performance
+ description: >
+ TF Lite is fast with no noticeable accuracy loss—see the <a href="./performance">metrics</a>.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Portability
description: >
- [high-level overview]
+ <a href="https://developer.android.com/ndk/guides/neuralnetworks/" class="external">Android</a>,
+ iOS, and more specialized IoT devices.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 2
+ - list:
+ - heading: Low latency
description: >
- [high-level overview]
+ Optimized float- and fixed-point CPU kernels, op&#8209;fusing, and more.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - heading: Key point 3
+ - heading: Acceleration
description: >
- [high-level overview]
+ Integration with GPU and internal/external accelerators.
icon:
- icon_name: chevron_right
+ icon_name: lens
foreground: theme
- background: grey
- - code_block: |
- <pre class = "prettyprint">
- $ toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
- --output_file=/tmp/mobilenet_v1_1.0_224.tflite \
- --inference_type=FLOAT \
- --input_type=FLOAT \
- --input_arrays=input \
- --output_arrays=MobilenetV1/Predictions/Reshape_1 \
- --input_shapes=1,224,224,3
- </pre>
+ - list:
+ - heading: Small model size
+ description: >
+ Controlled dependencies, <a href="https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3" class="external">quantization</a>,
+ and op&nbsp;registration.
+ icon:
+ icon_name: lens
+ foreground: theme
+ - heading: Tooling
+ description: >
+ Conversion, compression, benchmarking, power-consumption, and more.
+ icon:
+ icon_name: lens
+ foreground: theme
+
+ - classname: devsite-landing-row-logos tfo-landing-row-heading
+ heading: Companies using TensorFlow Lite
+ items:
+ - custom_image:
+ path: ./images/landing-page/photos_logo.png
+ path: https://www.photos.google.com
+ - custom_image:
+ path: ./images/landing-page/gboard_logo.png
+ path: https://play.google.com/store/apps/details?id=com.google.android.inputmethod.latin&hl=en_US
+ - custom_image:
+ path: ./images/landing-page/gmail_logo.png
+ path: https://www.google.com/gmail/
+ - custom_image:
+ path: ./images/landing-page/assistant_logo.png
+ path: https://assistant.google.com/
+
+ - classname: devsite-landing-row-logos
+ items:
+ - custom_image:
+ path: ./images/landing-page/vsco_logo.png
+ path: https://vsco.co
+ - custom_image:
+ path: ./images/landing-page/shazam_logo.png
+ path: https://www.shazam.com/
+ - custom_image:
+ path: ./images/landing-page/nest_logo.png
+ path: https://nest.com/
+ - custom_image:
+ path: ./images/landing-page/loseit_logo.png
+ path: https://www.loseit.com/
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ background: grey
+ items:
+ - description: >
+ <em>“TensorFlow Lite helped us introduce machine learning and AI into our
+ app in an easy and streamlined way. We could reduce the size of our
+ models while keeping the accuracy high. This helped us create an amazing
+ fishing experience for our users by allowing them to identify any fish
+ species with just a photo.”</em>
+ image_path: ./images/landing-page/fishbrain_logo_big.png
+
+ - heading: How it works
+ items:
+ - heading: Build
+ icon:
+ icon_name: build
+ description: >
+ Build a new model or retrain an existing one, such as using transfer learning.
+ buttons:
+ - label: Read the developer guide
+ path: /lite/devguide
+ classname: button button-primary tfo-button-primary
+ - heading: Convert
+ icon:
+ icon_name: autorenew
+ description: >
+ Convert a TensorFlow model into a compressed flat buffer with the
+ TensorFlow Lite Optimizing Converter (TOCO).
+ buttons:
+ - label: Read the TOCO guide
+ path: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/python_api.md
+ classname: button button-primary tfo-button-primary
+ - heading: Deploy
+ icon:
+ icon_name: bolt
+ description: >
+ Take the compressed <code>.tflite</code> file and load it into a mobile
+ or embedded device.<br/>
+ See the <a href="#build-your-first-tensorflow-lite-app">tutorials below</a> to build an app.
+
+ - heading: Build your first TensorFlow Lite app
+ background: grey
+ items:
+ - classname: tfo-landing-row-item-inset-white
+ heading: Get started
+ description: >
+ <ul>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/" class="external">TensorFlow for Poets</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-tflite/" class="external">TensorFlow for Poets 2: Android</a></li>
+ <li>Beginner: <a href="https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2-ios/" class="external">TensorFlow for Poets 2: iOS </a></li>
+ <li>Intermediate: <a href="https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193" class="external">Object detection tutorial</a>
+ </ul>
+ - classname: tfo-landing-row-item-inset-white
+ heading: Share your TensorFlow Lite story
+ description: >
+ We love to hear what you're working on—it may even get highlighted on
+ our social media! <a href="https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss" class="external">Tell us</a>.
+
+ - classname: devsite-landing-row-no-image-background devsite-landing-row-67
+ items:
+ - description: >
+ <p>
+ <em>“The release of TensorFlow Lite has allowed us to deploy an engaging
+ real-time experience to our users that eliminates the requirement
+ for a data connection. TensorFlow Lite’s ability to compress and
+ optimize the TensorFlow graph for mobile deployment has been
+ transformative in expanding the capabilities of Snap It.</em>
+ </p>
+ <p>
+ <em>Through TensorFlow Lite, our users can now enjoy a state of the
+ art, computer-vision-based food logging experience without worrying
+ about signal strength. We look forward to future collaborations
+ with the TensorFlow Lite team.”</em>
+ </p>
+ image_path: ./images/landing-page/loseit_logo_big.png
- classname: devsite-landing-row-cards
+ background: grey
+ heading: Updates
items:
+ - heading: Introducing the Model Optimization Toolkit
+ image_path: /ecosystem/images/tf-logo-card-16x9.png
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ buttons:
+ - label: Read on TensorFlow blog
+ path: https://medium.com/tensorflow/introducing-the-model-optimization-toolkit-for-tensorflow-254aca1ba0a3
+ - heading: East Africa Cassava App
+ image_path: ./images/landing-page/detect_crop_disease_in_africa.png
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
+ buttons:
+ - label: Read more
+ path: https://heartbeat.fritz.ai/community-spotlight-nuru-a-mobile-app-by-plantvillage-to-detect-crop-disease-in-africa-28d142bf63d5
- heading: Using TensorFlow Lite on Android
image_path: /ecosystem/images/tf-logo-card-16x9.png
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
buttons:
- label: Read on TensorFlow blog
path: https://medium.com/tensorflow/using-tensorflow-lite-on-android-9bbc9cb7d69d
+
+ - classname: devsite-landing-row-cards
+ background: grey
+ items:
- heading: TensorFlow Lite at the Dev Summit
youtube_id: FAMfy7izB6A
buttons:
@@ -65,3 +215,4 @@ landing_page:
buttons:
- label: View on GitHub
path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite
+ - classname: devsite-landing-row-item-hidden
diff --git a/tensorflow/contrib/lite/g3doc/_project.yaml b/tensorflow/contrib/lite/g3doc/_project.yaml
index b39666516b..3ce6986396 100644
--- a/tensorflow/contrib/lite/g3doc/_project.yaml
+++ b/tensorflow/contrib/lite/g3doc/_project.yaml
@@ -1,10 +1,10 @@
name: TensorFlow Lite
-breadcrumb_name: Mobile
-home_url: /mobile/
+breadcrumb_name: TensorFlow Lite
+home_url: /lite/
parent_project_metadata_path: /_project.yaml
description: >
TensorFlow Lite is a lightweight solution for mobile and embedded devices.
-use_site_branding: True
-hide_from_products_list: True
+use_site_branding: true
+hide_from_products_list: true
content_license: cc3-apache2
buganizer_id: 316308
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml b/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
deleted file mode 100644
index 1e1c44c692..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/_toc.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-# Automatically generated file; please do not edit
-toc:
- - title: TensorFlow Lite
- section:
- - title: Overview
- path: /mobile/api_docs/python/
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index 90e7915c52..0eed516000 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,4 @@
-
-# Developer Guide
+# TF Lite Developer Guide
Using a TensorFlow Lite model in your mobile app requires multiple
considerations: you must choose a pre-trained or custom model, convert the model
@@ -55,7 +54,7 @@ both floating point and quantized inference.
### Train a custom model
A developer may choose to train a custom model using Tensorflow (see the
-[TensorFlow tutorials](../../tutorials/) for examples of building and training
+[TensorFlow tutorials](../tutorials/) for examples of building and training
models). If you have already written a model, the first step is to export this
to a `tf.GraphDef` file. This is required because some formats do not store the
model structure outside the code, and we must communicate with other parts of the
@@ -205,7 +204,7 @@ The open source Android demo app uses the JNI interface and is available
[on GitHub](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
You can also download a
[prebuilt APK](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
-See the <a href="../demo_android.md">Android demo</a> guide for details.
+See the <a href="./demo_android.md">Android demo</a> guide for details.
The <a href="./android_build.md">Android mobile</a> guide has instructions for
installing TensorFlow on Android and setting up `bazel` and Android Studio.
@@ -214,7 +213,7 @@ installing TensorFlow on Android and setting up `bazel` and Android Studio.
To integrate a TensorFlow model in an iOS app, see the
[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md)
-guide and <a href="../demo_ios.md">iOS demo</a> guide.
+guide and <a href="./demo_ios.md">iOS demo</a> guide.
#### Core ML support
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
new file mode 100644
index 0000000000..ced0872ab2
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/assistant_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
new file mode 100644
index 0000000000..45b3b4f6fe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/detect_crop_disease_in_africa.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
new file mode 100644
index 0000000000..bc1bf6e1e7
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
new file mode 100644
index 0000000000..d76fca86a9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/fishbrain_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
new file mode 100644
index 0000000000..f1a93ab763
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gboard_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
new file mode 100644
index 0000000000..21aa2c84ea
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/gmail_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
new file mode 100644
index 0000000000..b6b3d14df9
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
new file mode 100644
index 0000000000..b3e46d4bd8
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/loseit_logo_big.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
new file mode 100644
index 0000000000..35bfd97373
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/nest_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
new file mode 100644
index 0000000000..4333426dfe
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/photos_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
new file mode 100644
index 0000000000..6ec412c75c
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/shazam_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
new file mode 100644
index 0000000000..f408f9024b
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/images/landing-page/vsco_logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index a83d2c8fec..3b9fcca811 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,10 @@
-# TensorFlow Lite for iOS
+# Build TensorFlow Lite for iOS
+
+This document describes how to build TensorFlow Lite iOS library. If you just
+want to use it, the easiest way is using the TensorFlow Lite CocoaPod releases.
+See [TensorFlow Lite iOS Demo](demo_ios.md) for examples.
+
## Building
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index a4267eee4c..279764ce96 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,6 +1,23 @@
# List of Hosted Models
+# AutoML mobile image classification models (Float Models)
+
+Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^
+------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ---------------------:
+MnasNet_0.50_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.5_224_09_07_2018.tgz) | 8.5 Mb | 68.03% | 87.79% | 37 ms
+MnasNet_0.75_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_0.75_224_09_07_2018.tgz) | 12 Mb | 71.72% | 90.17% | 61 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+MnasNet_1.3_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.3_224_09_07_2018.tgz) | 24 Mb | 75.24% | 92.55% | 152 ms
+MnasNet_1.0_96| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_96_09_07_2018.tgz) | 17 Mb | 62.33% | 83.98% | 23 ms
+MnasNet_1.0_128| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_128_09_07_2018.tgz) | 17 Mb | 67.32% | 87.70% | 34 ms
+MnasNet_1.0_160| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_160_09_07_2018.tgz) | 17 Mb | 70.63% | 89.58% | 51 ms
+MnasNet_1.0_192| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_192_09_07_2018.tgz) | 17 Mb | 72.56% | 90.76% | 70 ms
+MnasNet_1.0_224| [paper](https://arxiv.org/abs/1807.11626), [tflite&pb](https://storage.cloud.google.com/download.tensorflow.org/models/tflite/mnasnet_1.0_224_09_07_2018.tgz) | 17 Mb | 74.08% | 91.75% | 93 ms
+
+^ Performance numbers are generated on Pixel-1 using single thread large BIG core.
+
+
## Image classification (Float Models)
Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index 8cf43496df..9d035a6921 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -25,7 +25,7 @@ models.
TensorFlow Lite defines a new model file format, based on
[FlatBuffers](https://google.github.io/flatbuffers/). FlatBuffers is an
-open-sourced, efficient cross platform serialization library. It is similar to
+efficient open-source cross-platform serialization library. It is similar to
[protocol buffers](https://developers.google.com/protocol-buffers/?hl=en), but
the primary difference is that FlatBuffers does not need a parsing/unpacking
step to a secondary representation before you can access data, often coupled
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 28cb6aba6e..0ae9400068 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,174 +1,38 @@
-# Performance
+# Performance best practices
-This document lists TensorFlow Lite performance benchmarks when running well
-known models on some Android and iOS devices.
+Mobile and embedded devices have limited computational resources and it is important to keep your application resource efficient. We have compiled a list of best practices and strategies you can use to optimize your model and application when using Tensorflow Lite.
-These performance benchmark numbers were generated with the
-[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
-and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+## Choose the most efficient model for the problem
+Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
-# Android performance benchmarks
+You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for
+[image classification] (https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
+ [object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193).
-For Android benchmarks, the CPU affinity is set to use big cores on the device to
-reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
-It assumes that models were download and unzipped to the
-`/data/local/tmp/tflite_models` directory. The benchmark binary is built
-using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
-and assumed in the `/data/local/tmp` directory.
+## Profile your model
+Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
-To run the benchmark:
+## Profile and optimize operators in the graph
+If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator.
+ This scenario should be rare as Tensorflow Lite has optimized versions for most ops. However you may be able to write a faster version of a custom op, if you know the constraints in which the operator is executed. Check out our [custom operator documentation](custom_operators.md).
-```
-adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
- --num_threads=1 \
- --graph=/data/local/tmp/tflite_models/${GRAPH} \
- --warmup_runs=1 \
- --num_runs=50 \
- --use_nnapi=false
-```
+## Quantize your model
+If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
-Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
-chosen according to the following table:
+## Tweak the number of threads
+Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads.
-Device | CPU_MASK |
--------| ----------
-Pixel 2 | f0 |
-Pixel xl | 0c |
+## Eliminate redundant copies
+Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>Pixel 2 </td>
- <td>166.5 ms (2.6 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>122.9 ms (1.8 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>Pixel 2 </td>
- <td>69.5 ms (0.9 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>78.9 ms (2.2 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>Pixel 2 </td>
- <td>273.8 ms (3.5 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>210.8 ms (4.2 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>Pixel 2 </td>
- <td>234.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>158.0 ms (2.1 ms)</td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>Pixel 2 </td>
- <td>2846.0 ms (15.0 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>1973.0 ms (15.0 ms) </td>
- </tr>
- <tr>
- <td rowspan = 2>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>Pixel 2 </td>
- <td>3180.0 ms (11.7 ms)</td>
- </tr>
- <tr>
- <td>Pixel xl </td>
- <td>2262.0 ms (21.0 ms) </td>
- </tr>
+## Profile your application with platform specific tools
+Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform.
- </table>
+## Use hardware accelerators available on the device
+Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
+You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable NNAPI call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
-# iOS benchmarks
-
-To run iOS benchmarks, the [benchmark
-app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
-was modified to include the appropriate model and `benchmark_params.json` was
-modified to set `num_threads` to 1.
-
-<table>
- <thead>
- <tr>
- <th>Model Name</th>
- <th>Device </th>
- <th>Mean inference time (std dev)</th>
- </tr>
- </thead>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
- </td>
- <td>iPhone 8 </td>
- <td>32.2 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
- </td>
- <td>iPhone 8 </td>
- <td>24.4 ms (0.8 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
- </td>
- <td>iPhone 8 </td>
- <td>60.3 ms (0.6 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
- </td>
- <td>iPhone 8 </td>
- <td>44.3 (0.7 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
- </td>
- <td>iPhone 8</td>
- <td>562.4 ms (18.2 ms)</td>
- </tr>
- <tr>
- <td>
- <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
- </td>
- <td>iPhone 8 </td>
- <td>661.0 ms (29.2 ms)</td>
- </tr>
- </table>
+## Need more help
+The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue.
diff --git a/tensorflow/contrib/lite/g3doc/performance_benchmarks.md b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
new file mode 100644
index 0000000000..28cb6aba6e
--- /dev/null
+++ b/tensorflow/contrib/lite/g3doc/performance_benchmarks.md
@@ -0,0 +1,174 @@
+
+# Performance
+
+This document lists TensorFlow Lite performance benchmarks when running well
+known models on some Android and iOS devices.
+
+These performance benchmark numbers were generated with the
+[Android TFLite benchmark binary](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark)
+and the [iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+
+# Android performance benchmarks
+
+For Android benchmarks, the CPU affinity is set to use big cores on the device to
+reduce variance (see [details](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#reducing-variance-between-runs-on-android)).
+
+It assumes that models were download and unzipped to the
+`/data/local/tmp/tflite_models` directory. The benchmark binary is built
+using [these instructions](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark#on-android)
+and assumed in the `/data/local/tmp` directory.
+
+To run the benchmark:
+
+```
+adb shell taskset ${CPU_MASK} /data/local/tmp/benchmark_model \
+ --num_threads=1 \
+ --graph=/data/local/tmp/tflite_models/${GRAPH} \
+ --warmup_runs=1 \
+ --num_runs=50 \
+ --use_nnapi=false
+```
+
+Here, `${GRAPH}` is the name of model and `${CPU_MASK}` is the CPU affinity
+chosen according to the following table:
+
+Device | CPU_MASK |
+-------| ----------
+Pixel 2 | f0 |
+Pixel xl | 0c |
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>166.5 ms (2.6 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>122.9 ms (1.8 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>69.5 ms (0.9 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>78.9 ms (2.2 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>273.8 ms (3.5 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>210.8 ms (4.2 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>234.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>158.0 ms (2.1 ms)</td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>2846.0 ms (15.0 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>1973.0 ms (15.0 ms) </td>
+ </tr>
+ <tr>
+ <td rowspan = 2>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>Pixel 2 </td>
+ <td>3180.0 ms (11.7 ms)</td>
+ </tr>
+ <tr>
+ <td>Pixel xl </td>
+ <td>2262.0 ms (21.0 ms) </td>
+ </tr>
+
+ </table>
+
+# iOS benchmarks
+
+To run iOS benchmarks, the [benchmark
+app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios)
+was modified to include the appropriate model and `benchmark_params.json` was
+modified to set `num_threads` to 1.
+
+<table>
+ <thead>
+ <tr>
+ <th>Model Name</th>
+ <th>Device </th>
+ <th>Mean inference time (std dev)</th>
+ </tr>
+ </thead>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz">Mobilenet_1.0_224(float)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>32.2 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz)">Mobilenet_1.0_224 (quant)</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>24.4 ms (0.8 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz">NASNet mobile</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>60.3 ms (0.6 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz">SqueezeNet</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>44.3 (0.7 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz">Inception_ResNet_V2</a>
+ </td>
+ <td>iPhone 8</td>
+ <td>562.4 ms (18.2 ms)</td>
+ </tr>
+ <tr>
+ <td>
+ <a href="https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz">Inception_V4</a>
+ </td>
+ <td>iPhone 8 </td>
+ <td>661.0 ms (29.2 ms)</td>
+ </tr>
+ </table>
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 8660d29855..b0dfb0fed1 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -866,6 +866,17 @@ Outputs {
}
```
+**ZEROS_LIKE**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: A tensor of the same shape and type as x but filled with zeros
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index c7cdee07de..b0f32a8d6c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -93,7 +93,7 @@ requires some knowledge of build systems and Android developer tools, but we'll
guide you through the basics here.
- First, follow our instructions for
- <a href="http://www.tensorflow.org/install/install_sources">installing from sources</a>.
+ <a href="http://www.tensorflow.org/install/source">installing from sources</a>.
This will also guide you through installing Bazel and cloning the
TensorFlow code.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index d003bb2f38..49ad35d4e6 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -4,7 +4,7 @@
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
-<a href="../index.md">TensorFlow Lite</a>.
+<a href="../../lite">TensorFlow Lite</a>.
## TensorFlow Lite versus TensorFlow Mobile
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 3f8f4d198f..2657bcd42b 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -123,6 +123,7 @@ Interpreter::Interpreter(ErrorReporter* error_reporter)
context_.AddTensors = AddTensors;
context_.tensors = nullptr;
context_.tensors_size = 0;
+ context_.allow_fp32_relax_to_fp16 = false;
context_.recommended_num_threads = -1;
context_.GetExternalContext = GetExternalContext;
context_.SetExternalContext = SetExternalContext;
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index f0cd178c19..aa2bc4def6 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -336,6 +336,19 @@ class Interpreter {
// Set the number of threads available to the interpreter.
void SetNumThreads(int num_threads);
+ // Allow float16 precision for FP32 calculation when possible.
+ // default: not allow.
+ // WARNING: This is an experimental API and subject to change.
+ void SetAllowFp16PrecisionForFp32(bool allow) {
+ context_.allow_fp32_relax_to_fp16 = allow;
+ }
+
+ // Get the half precision flag.
+ // WARNING: This is an experimental API and subject to change.
+ bool GetAllowFp16PrecisionForFp32() const {
+ return context_.allow_fp32_relax_to_fp16;
+ }
+
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 more nodes.
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 6a3f0651d0..c04b2a6194 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -1,4 +1,6 @@
-# TF Lite Android App
+# TF Lite Android Image Classifier App Example
+
+A simple Android example that demonstrates image classification using the camera.
## Building in Android Studio with TensorFlow Lite AAR from JCenter.
The build.gradle is configured to use TensorFlow Lite's nightly build.
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 4f5662bc2d..3596e42011 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -58,9 +58,9 @@ import android.view.View;
import android.view.ViewGroup;
import android.widget.CompoundButton;
import android.widget.NumberPicker;
-import android.widget.ToggleButton;
import android.widget.TextView;
import android.widget.Toast;
+import android.widget.ToggleButton;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@@ -305,22 +305,24 @@ public class Camera2BasicFragment extends Fragment
textView = (TextView) view.findViewById(R.id.text);
toggle = (ToggleButton) view.findViewById(R.id.button);
- toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() {
- public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
- classifier.setUseNNAPI(isChecked);
- }
- });
+ toggle.setOnCheckedChangeListener(
+ new CompoundButton.OnCheckedChangeListener() {
+ public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) {
+ backgroundHandler.post(() -> classifier.setUseNNAPI(isChecked));
+ }
+ });
np = (NumberPicker) view.findViewById(R.id.np);
np.setMinValue(1);
np.setMaxValue(10);
np.setWrapSelectorWheel(true);
- np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() {
- @Override
- public void onValueChange(NumberPicker picker, int oldVal, int newVal){
- classifier.setNumThreads(newVal);
- }
- });
+ np.setOnValueChangedListener(
+ new NumberPicker.OnValueChangeListener() {
+ @Override
+ public void onValueChange(NumberPicker picker, int oldVal, int newVal) {
+ backgroundHandler.post(() -> classifier.setNumThreads(newVal));
+ }
+ });
}
/** Load the model and labels. */
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index 7bb6afd9d8..2d11a57434 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -59,9 +59,15 @@ public abstract class ImageClassifier {
private static final int DIM_PIXEL_SIZE = 3;
- /* Preallocated buffers for storing image data in. */
+ /** Preallocated buffers for storing image data in. */
private int[] intValues = new int[getImageSizeX() * getImageSizeY()];
+ /** Options for configuring the Interpreter. */
+ private final Interpreter.Options tfliteOptions = new Interpreter.Options();
+
+ /** The loaded TensorFlow Lite model. */
+ private MappedByteBuffer tfliteModel;
+
/** An instance of the driver class to run model inference with Tensorflow Lite. */
protected Interpreter tflite;
@@ -89,7 +95,8 @@ public abstract class ImageClassifier {
/** Initializes an {@code ImageClassifier}. */
ImageClassifier(Activity activity) throws IOException {
- tflite = new Interpreter(loadModelFile(activity));
+ tfliteModel = loadModelFile(activity);
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
labelList = loadLabelList(activity);
imgData =
ByteBuffer.allocateDirect(
@@ -150,20 +157,28 @@ public abstract class ImageClassifier {
}
}
+ private void recreateInterpreter() {
+ if (tflite != null) {
+ tflite.close();
+ tflite = new Interpreter(tfliteModel, tfliteOptions);
+ }
+ }
+
public void setUseNNAPI(Boolean nnapi) {
- if (tflite != null)
- tflite.setUseNNAPI(nnapi);
+ tfliteOptions.setUseNNAPI(nnapi);
+ recreateInterpreter();
}
- public void setNumThreads(int num_threads) {
- if (tflite != null)
- tflite.setNumThreads(num_threads);
+ public void setNumThreads(int numThreads) {
+ tfliteOptions.setNumThreads(numThreads);
+ recreateInterpreter();
}
/** Closes tflite to release resources. */
public void close() {
tflite.close();
tflite = null;
+ tfliteModel = null;
}
/** Reads label list from Assets. */
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 781289ceb2..bb0be04ca2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -44,6 +44,7 @@ java_binary(
android_library(
name = "ovicbenchmarkerlib",
srcs = [
+ "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md
index 26349347fa..df77bfaab3 100644
--- a/tensorflow/contrib/lite/java/ovic/README.md
+++ b/tensorflow/contrib/lite/java/ovic/README.md
@@ -4,7 +4,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco
## Pre-requisite
-Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
+Follow the steps [here](https://www.tensorflow.org/lite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK.
## Test the benchmarker:
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index a8d751ade2..b2e3a9bd7d 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -6,7 +6,6 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "ovic_benchmarker_binary",
srcs = [
- "OvicBenchmarker.java",
"OvicBenchmarkerActivity.java",
],
assets = [
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
index 59457c308a..4adf94aeb6 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -34,8 +34,10 @@ import java.io.InputStream;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.text.DecimalFormat;
+import org.tensorflow.ovic.OvicBenchmarker;
import org.tensorflow.ovic.OvicSingleImageResult;
+
/** Class that benchmark image classifier models. */
public class OvicBenchmarkerActivity extends Activity {
/** Tag for the {@link Log}. */
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
index 113ab74a20..4cda258bee 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-package ovic.demo.app;
+package org.tensorflow.ovic;
import android.graphics.Bitmap;
import android.os.SystemClock;
@@ -22,8 +22,6 @@ import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
-import org.tensorflow.ovic.OvicClassifier;
-import org.tensorflow.ovic.OvicSingleImageResult;
/**
* Class that benchmarks image classifier models.
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index 4cf51bb0fa..fd610b054f 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -74,7 +74,7 @@ public class OvicClassifier {
}
labelList = loadLabelList(labelInputStream);
// OVIC uses one thread for CPU inference.
- tflite = new Interpreter(model, 1);
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
inputDims = TestHelper.getInputDims(tflite, 0);
if (inputDims.length != 4) {
throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index b84720ae8e..ffb04496cb 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -17,7 +17,6 @@ package org.tensorflow.lite;
import java.io.File;
import java.nio.ByteBuffer;
-import java.nio.MappedByteBuffer;
import java.util.HashMap;
import java.util.Map;
import org.checkerframework.checker.nullness.qual.NonNull;
@@ -56,16 +55,36 @@ import org.checkerframework.checker.nullness.qual.NonNull;
*/
public final class Interpreter implements AutoCloseable {
+ /** An options class for controlling runtime interpreter behavior. */
+ public static class Options {
+ public Options() {}
+
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading. Defaults to a
+ * platform-dependent value.
+ */
+ public Options setNumThreads(int numThreads) {
+ this.numThreads = numThreads;
+ return this;
+ }
+
+ /** Sets whether to use NN API (if available) for op execution. Defaults to false (disabled). */
+ public Options setUseNNAPI(boolean useNNAPI) {
+ this.useNNAPI = useNNAPI;
+ return this;
+ }
+
+ int numThreads = -1;
+ boolean useNNAPI = false;
+ }
+
/**
* Initializes a {@code Interpreter}
*
* @param modelFile: a File of a pre-trained TF Lite model.
*/
public Interpreter(@NonNull File modelFile) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath());
+ this(modelFile, /*options = */ null);
}
/**
@@ -73,12 +92,22 @@ public final class Interpreter implements AutoCloseable {
*
* @param modelFile: a file of a pre-trained TF Lite model
* @param numThreads: number of threads to use for inference
+ * @deprecated Prefer using the {@link #Interpreter(File,Options)} constructor. This method will
+ * be removed in a future release.
*/
+ @Deprecated
public Interpreter(@NonNull File modelFile, int numThreads) {
- if (modelFile == null) {
- return;
- }
- wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), numThreads);
+ this(modelFile, new Options().setNumThreads(numThreads));
+ }
+
+ /**
+ * Initializes a {@code Interpreter} and specifies the number of threads used for inference.
+ *
+ * @param modelFile: a file of a pre-trained TF Lite model
+ * @param options: a set of options for customizing interpreter behavior
+ */
+ public Interpreter(@NonNull File modelFile, Options options) {
+ wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
/**
@@ -89,7 +118,7 @@ public final class Interpreter implements AutoCloseable {
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
public Interpreter(@NonNull ByteBuffer byteBuffer) {
- wrapper = new NativeInterpreterWrapper(byteBuffer);
+ this(byteBuffer, /* options= */ null);
}
/**
@@ -99,30 +128,25 @@ public final class Interpreter implements AutoCloseable {
* <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
* {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
* direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
- public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads);
- }
-
- /**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * @deprecated Prefer using the {@link #Interpreter(ByteBuffer,Options)} constructor. This method
+ * will be removed in a future release.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer);
+ @Deprecated
+ public Interpreter(@NonNull ByteBuffer byteBuffer, int numThreads) {
+ this(byteBuffer, new Options().setNumThreads(numThreads));
}
/**
- * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and
- * specifies the number of threads used for inference.
+ * Initializes a {@code Interpreter} with a {@code ByteBuffer} of a model file and a set of custom
+ * {@link #Options}.
*
- * <p>The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code
- * Interpreter}.
+ * <p>The ByteBuffer should not be modified after the construction of a {@code Interpreter}. The
+ * {@code ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a
+ * direct {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
*/
- public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) {
- wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads);
+ public Interpreter(@NonNull ByteBuffer byteBuffer, Options options) {
+ wrapper = new NativeInterpreterWrapper(byteBuffer, options);
}
/**
@@ -240,12 +264,25 @@ public final class Interpreter implements AutoCloseable {
return wrapper.getLastNativeInferenceDurationNanoseconds();
}
- /** Turns on/off Android NNAPI for hardware acceleration when it is available. */
+ /**
+ * Turns on/off Android NNAPI for hardware acceleration when it is available.
+ *
+ * @deprecated Prefer using {@link Options#setUseNNAPI(boolean)} directly for enabling NN API.
+ * This method will be removed in a future release.
+ */
+ @Deprecated
public void setUseNNAPI(boolean useNNAPI) {
checkNotClosed();
wrapper.setUseNNAPI(useNNAPI);
}
+ /**
+ * Sets the number of threads to be used for ops that support multi-threading.
+ *
+ * @deprecated Prefer using {@link Options#setNumThreads(int)} directly for controlling thread
+ * multi-threading. This method will be removed in a future release.
+ */
+ @Deprecated
public void setNumThreads(int numThreads) {
checkNotClosed();
wrapper.setNumThreads(numThreads);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index fa25082304..6feff9a618 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -23,7 +23,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * A wrapper wraps native interpreter and controls model execution.
+ * An internal wrapper that wraps native interpreter and controls model execution.
*
* <p><b>WARNING:</b> Resources consumed by the {@code NativeInterpreterWrapper} object must be
* explicitly freed by invoking the {@link #close()} method when the {@code
@@ -32,36 +32,29 @@ import java.util.Map;
final class NativeInterpreterWrapper implements AutoCloseable {
NativeInterpreterWrapper(String modelPath) {
- this(modelPath, /* numThreads= */ -1);
+ this(modelPath, /* options= */ null);
}
- NativeInterpreterWrapper(String modelPath, int numThreads) {
+ NativeInterpreterWrapper(String modelPath, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModel(modelPath, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer}. The ByteBuffer should
- * not be modified after the construction of a {@code NativeInterpreterWrapper}. The {@code
- * ByteBuffer} can be either a {@code MappedByteBuffer} that memory-maps a model file, or a direct
- * {@code ByteBuffer} of nativeOrder() that contains the bytes content of a model.
- */
NativeInterpreterWrapper(ByteBuffer byteBuffer) {
- this(byteBuffer, /* numThreads= */ -1);
+ this(byteBuffer, /* options= */ null);
}
- /**
- * Initializes a {@code NativeInterpreterWrapper} with a {@code ByteBuffer} and specifies the
- * number of inference threads. The ByteBuffer should not be modified after the construction of a
- * {@code NativeInterpreterWrapper}. The {@code ByteBuffer} can be either a {@code
- * MappedByteBuffer} that memory-maps a model file, or a direct {@code ByteBuffer} of
- * nativeOrder() that contains the bytes content of a model.
- */
- NativeInterpreterWrapper(ByteBuffer buffer, int numThreads) {
+ NativeInterpreterWrapper(ByteBuffer buffer, Interpreter.Options options) {
+ if (options == null) {
+ options = new Interpreter.Options();
+ }
if (buffer == null
|| (!(buffer instanceof MappedByteBuffer)
&& (!buffer.isDirect() || buffer.order() != ByteOrder.nativeOrder()))) {
@@ -72,10 +65,13 @@ final class NativeInterpreterWrapper implements AutoCloseable {
modelByteBuffer = buffer;
errorHandle = createErrorReporter(ERROR_BUFFER_SIZE);
modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle);
- interpreterHandle = createInterpreter(modelHandle, errorHandle, numThreads);
+ interpreterHandle = createInterpreter(modelHandle, errorHandle, options.numThreads);
isMemoryAllocated = true;
inputTensors = new Tensor[getInputCount(interpreterHandle)];
outputTensors = new Tensor[getOutputCount(interpreterHandle)];
+ if (options.useNNAPI) {
+ setUseNNAPI(options.useNNAPI);
+ }
}
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index 9070b788b6..dfdd7d22b0 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -55,6 +55,18 @@ public final class InterpreterTest {
}
@Test
+ public void testInterpreterWithOptions() throws Exception {
+ Interpreter interpreter =
+ new Interpreter(MODEL_FILE, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(interpreter).isNotNull();
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(1);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.close();
+ }
+
+ @Test
public void testRunWithMappedByteBufferModel() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
@@ -304,40 +316,14 @@ public final class InterpreterTest {
}
@Test
- public void testTurnOffNNAPI() throws Exception {
- Path path = MODEL_FILE.toPath();
- FileChannel fileChannel =
- (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
- MappedByteBuffer mappedByteBuffer =
- fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
- float[] oneD = {1.23f, 6.54f, 7.81f};
- float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
- float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
- float[][][][] fourD = {threeD, threeD};
- float[][][][] parsedOutputs = new float[2][8][8][3];
- interpreter.run(fourD, parsedOutputs);
- float[] outputOneD = parsedOutputs[0][0][0];
- float[] expected = {3.69f, 19.62f, 23.43f};
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.setUseNNAPI(false);
- interpreter.run(fourD, parsedOutputs);
- outputOneD = parsedOutputs[0][0][0];
- assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder();
- interpreter.close();
- fileChannel.close();
- }
-
- @Test
public void testTurnOnNNAPI() throws Exception {
Path path = MODEL_FILE.toPath();
FileChannel fileChannel =
(FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ));
MappedByteBuffer mappedByteBuffer =
fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
- Interpreter interpreter = new Interpreter(mappedByteBuffer);
- interpreter.setUseNNAPI(true);
+ Interpreter interpreter =
+ new Interpreter(mappedByteBuffer, new Interpreter.Options().setUseNNAPI(true));
float[] oneD = {1.23f, 6.54f, 7.81f};
float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD};
float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 9c4a5acd79..270bd6703a 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -63,6 +63,15 @@ public final class NativeInterpreterWrapperTest {
}
@Test
+ public void testConstructorWithOptions() {
+ NativeInterpreterWrapper wrapper =
+ new NativeInterpreterWrapper(
+ FLOAT_MODEL_PATH, new Interpreter.Options().setNumThreads(2).setUseNNAPI(true));
+ assertThat(wrapper).isNotNull();
+ wrapper.close();
+ }
+
+ @Test
public void testConstructorWithInvalidModel() {
try {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH);
diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
index 38b740021b..af20e3280b 100644
--- a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
+++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java
@@ -19,21 +19,6 @@ package org.tensorflow.lite;
public class TestHelper {
/**
- * Turns on/off NNAPI of an {@code Interpreter}.
- *
- * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code
- * IllegalArgumentException} will be thrown.
- * @param useNNAPI a boolean value indicating to turn on or off NNAPI.
- */
- public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) {
- if (interpreter != null && interpreter.wrapper != null) {
- interpreter.wrapper.setUseNNAPI(useNNAPI);
- } else {
- throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI.");
- }
- }
-
- /**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
*
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 40f28aeab4..daaf6714cc 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -223,6 +223,7 @@ cc_library(
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
"unpack.cc",
+ "zeros_like.cc",
],
hdrs = [
],
@@ -508,6 +509,7 @@ tf_cc_test(
":builtin_ops",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_absl//absl/memory",
"@com_google_googletest//:gtest",
],
)
@@ -1284,6 +1286,20 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "zeros_like_test",
+ size = "small",
+ srcs = ["zeros_like_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 44ef587244..0d2d5e775f 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
index 7346b9fd80..7e4ff6fc16 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index ab6bdaecaa..101b4fc961 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -414,35 +414,57 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
switch (effective_kernel_type) {
- case kReference:
+ case kReference: {
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
reference_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- params->dilation_width_factor, params->dilation_height_factor,
- data->padding.width, data->padding.height, output_offset,
- data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
case kGenericOptimized:
case kMultithreadOptimized:
- case kCblasOptimized:
+ case kCblasOptimized: {
// There is only one optimized implementation for Quantized Conv.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
optimized_ops::Conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- params->stride_width, params->stride_height,
- params->dilation_width_factor, params->dilation_height_factor,
- data->padding.width, data->padding.height, output_offset,
- data->output_multiplier, data->output_shift,
- data->output_activation_min, data->output_activation_max,
- GetTensorData<uint8_t>(output), GetTensorDims(output),
- GetTensorData<uint8_t>(im2col), GetTensorDims(im2col), gemm_context);
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output),
+ GetTensorShape(im2col), GetTensorData<uint8_t>(im2col), gemm_context);
break;
+ }
}
}
@@ -467,27 +489,41 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
switch (effective_kernel_type) {
case kReference: {
- reference_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ reference_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kGenericOptimized: {
- optimized_ops::Conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, params->dilation_width_factor,
- params->dilation_height_factor, data->padding.width,
- data->padding.height, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output),
- GetTensorData<float>(im2col), GetTensorDims(im2col));
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ optimized_ops::Conv(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(filter),
+ GetTensorData<float>(filter), GetTensorShape(bias),
+ GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output), GetTensorShape(im2col),
+ GetTensorData<float>(im2col));
break;
}
case kMultithreadOptimized: {
@@ -561,18 +597,27 @@ void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
case kReference:
case kGenericOptimized:
case kMultithreadOptimized:
- case kCblasOptimized:
+ case kCblasOptimized: {
// There is only one implementation for hybrid kernel. Note
// this does not make use of gemmlowp nor supports multithreading.
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
optimized_ops::HybridConv(
- quantized_input_ptr_batch, GetTensorDims(input), filter_ptr,
- GetTensorDims(filter), GetTensorData<float>(bias),
- GetTensorDims(bias), params->stride_width, params->stride_height,
- data->padding.width, data->padding.height, scaling_factors_ptr,
- output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output), im2col_ptr,
- GetTensorDims(im2col));
+ op_params, scaling_factors_ptr, GetTensorShape(input),
+ quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr,
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output),
+ GetTensorShape(im2col), im2col_ptr);
break;
+ }
}
}
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 411615aa62..f7e6f083ed 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -177,6 +177,30 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
}));
}
+TEST_P(ConvolutionOpTest, InputAndFilterSameWidthHeight) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_FLOAT32, {1, 2, 4, 1}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // row = 1
+ -1, -1, 1, 1, // row = 2
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 34}));
+}
+
TEST_P(ConvolutionOpTest, PointwiseFloat32) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
{TensorType_FLOAT32, {1, 1, 1, 2}},
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 347515f289..19958844a1 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -126,23 +126,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Matching GetWindowedOutputSize in TensorFlow.
auto padding = params->padding;
- auto compute_out_size = [padding](int imageSize, int filterSize,
- int stride) -> int {
+ auto compute_out_size = [padding](int image_size, int filter_size, int stride,
+ int dilation_rate) -> int {
+ int effective_filter_size = (filter_size - 1) * dilation_rate + 1;
return padding == kTfLitePaddingSame
- ? (imageSize + stride - 1) / stride
+ ? (image_size + stride - 1) / stride
: padding == kTfLitePaddingValid
- ? (imageSize - filterSize + stride) / stride
+ ? (image_size - effective_filter_size + stride) / stride
: 0;
};
- int out_width = compute_out_size(width, filter_width, params->stride_width);
+ int out_width = compute_out_size(width, filter_width, params->stride_width,
+ params->dilation_width_factor);
int out_height =
- compute_out_size(height, filter_height, params->stride_height);
+ compute_out_size(height, filter_height, params->stride_height,
+ params->dilation_height_factor);
- data->padding.height = ComputePadding(params->stride_height, 1, height,
- filter_height, out_height);
+ data->padding.height =
+ ComputePadding(params->stride_height, params->dilation_height_factor,
+ height, filter_height, out_height);
data->padding.width =
- ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+ ComputePadding(params->stride_width, params->dilation_width_factor, width,
+ filter_width, out_width);
// Note that quantized inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
@@ -175,22 +180,31 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
- void (*depthwise_conv)(const float*, const Dims<4>&, const float*,
- const Dims<4>&, const float*, const Dims<4>&, int, int,
- int, int, int, float, float, float*, const Dims<4>&);
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const float*, const RuntimeShape&, const float*,
+ const RuntimeShape&, const float*, const RuntimeShape&,
+ float*);
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(filter), GetTensorDims(filter),
- GetTensorData<float>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- params->depth_multiplier, output_activation_min, output_activation_max,
- GetTensorData<float>(output), GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(output), GetTensorData<float>(output));
}
template <KernelType kernel_type>
@@ -202,25 +216,38 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
auto filter_offset = -filter->params.zero_point;
auto output_offset = output->params.zero_point;
- void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*,
- const Dims<4>&, int32, const int32*, const Dims<4>&,
- int, int, int, int, int, int32, int32, int, int32,
- int32, uint8*, const Dims<4>&);
+ void (*depthwise_conv)(const DepthwiseParams&, const RuntimeShape&,
+ const uint8*, const RuntimeShape&, const uint8*,
+ const RuntimeShape&, const int32*, const RuntimeShape&,
+ uint8*);
+
if (kernel_type == kReference) {
depthwise_conv = &reference_ops::DepthwiseConv;
} else {
depthwise_conv = &optimized_ops::DepthwiseConv;
}
- depthwise_conv(
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset,
- GetTensorData<int32_t>(bias), GetTensorDims(bias), params->stride_width,
- params->stride_height, data->padding.width, data->padding.height,
- params->depth_multiplier, output_offset, data->output_multiplier,
- data->output_shift, data->output_activation_min,
- data->output_activation_max, GetTensorData<uint8_t>(output),
- GetTensorDims(output));
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+ depthwise_conv(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(filter),
+ GetTensorData<uint8_t>(filter), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
template <KernelType kernel_type>
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
index c00cafb9fb..4a33a0319d 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc
@@ -14,12 +14,24 @@ limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
#include "tensorflow/contrib/lite/model.h"
namespace tflite {
+
+namespace ops {
+namespace builtin {
+
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT();
+TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT();
+
+} // namespace builtin
+} // namespace ops
+
namespace {
using ::testing::ElementsAreArray;
@@ -28,9 +40,12 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
public:
// TODO(ahentz): Also test different activation types, bias, padding types,
// stride values.
- BaseDepthwiseConvolutionOpModel(const TensorData& input,
+ BaseDepthwiseConvolutionOpModel(TfLiteRegistration* registration,
+ const TensorData& input,
const TensorData& filter,
- const TensorData& output) {
+ const TensorData& output,
+ Padding padding_type,
+ int dilation_factor = 1) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -55,10 +70,14 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel {
SetBuiltinOp(
BuiltinOperator_DEPTHWISE_CONV_2D,
BuiltinOptions_DepthwiseConv2DOptions,
- CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul,
- ActivationFunctionType_NONE)
+ CreateDepthwiseConv2DOptions(builder_, padding_type, 1, 1, depth_mul,
+ ActivationFunctionType_NONE,
+ dilation_factor, dilation_factor)
.Union());
+ resolver_ = absl::make_unique<SingleOpResolver>(
+ BuiltinOperator_DEPTHWISE_CONV_2D, registration);
+
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
@@ -84,10 +103,25 @@ class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel {
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
};
-TEST(DepthwiseConvolutionOpTest, SimpleTest) {
- DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}},
+const auto kKernelMap = new std::map<string, TfLiteRegistration*>({
+ {"Reference", ops::builtin::Register_DEPTHWISE_CONVOLUTION_REF()},
+ {"GenericOptimized",
+ ops::builtin::Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT()},
+ {"NeonOptimized", ops::builtin::Register_DEPTHWISE_CONVOLUTION_NEON_OPT()},
+});
+
+class DepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleTest) {
+ DepthwiseConvolutionOpModel m(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -110,6 +144,94 @@ TEST(DepthwiseConvolutionOpTest, SimpleTest) {
}));
}
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_VALID, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Since the dilation rate is 3 this will reduce the size of the output from
+ // 10x10 to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
+TEST_P(DepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ DepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_FLOAT32,
+ {image_batch_count, image_height, image_width, depth}},
+ {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}},
+ {TensorType_FLOAT32, {}}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
class QuantizedDepthwiseConvolutionOpModel
: public BaseDepthwiseConvolutionOpModel {
public:
@@ -134,13 +256,20 @@ class QuantizedDepthwiseConvolutionOpModel
}
};
+class QuantizedDepthwiseConvolutionOpTest : public SingleOpTest {
+ protected:
+ const std::map<string, TfLiteRegistration*>& GetKernelMap() override {
+ return *kKernelMap;
+ }
+};
+
// In this test we set the input and output scales so that the results match
// exactly the 'non-quantized' version.
-TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
QuantizedDepthwiseConvolutionOpModel m(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64},
- {TensorType_UINT8, {}, -127, 128});
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
m.SetInput({
1, 2, 7, 8, // column 1
@@ -170,15 +299,16 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) {
}));
}
-TEST(QuantizedDepthwiseConvolutionOpTest,
- SimpleTestQuantizedFilterMultiplierGreaterThan1) {
+TEST_P(QuantizedDepthwiseConvolutionOpTest,
+ SimpleTestQuantizedFilterMultiplierGreaterThan1) {
QuantizedDepthwiseConvolutionOpModel quant_op(
- {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
+ GetRegistration(), {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64},
{TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128},
- {TensorType_UINT8, {}, -127, 128});
- DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}},
+ {TensorType_UINT8, {}, -127, 128}, Padding_VALID);
+ DepthwiseConvolutionOpModel float_op(GetRegistration(),
+ {TensorType_FLOAT32, {1, 3, 2, 2}},
{TensorType_FLOAT32, {1, 2, 2, 4}},
- {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}}, Padding_VALID);
std::initializer_list<float> input = {
1, 2, 7, 8, // column 1
@@ -207,6 +337,114 @@ TEST(QuantizedDepthwiseConvolutionOpTest,
ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1)));
}
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingValid) {
+ const int depth = 1;
+ const int image_width = 9;
+ const int image_height = 9;
+ const int image_batch_count = 1;
+ const int filter_size = 3;
+ const int filter_count = 1;
+ const int dilation_factor = 3;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_VALID, dilation_factor);
+
+ // The image matrix is:
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 1 | 1 | 1 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+ // clang-format off
+ m.SetInput({0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0});
+ // clang-format on
+ // The filter matrix is:
+ // | 1 | 2 | 3 |
+ // | 4 | 5 | 6 |
+ // | 7 | 8 | 9 |
+ m.SetFilter({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Since the dilation rate is 3 this will reduce the size of the output from
+ // 10x10 to 3x3 of all 5s. Specifically:
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ // | 5 | 5 | 5 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
+}
+
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleDilatedTestPaddingSame) {
+ const int depth = 1;
+ const int image_width = 3;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ const int filter_size = 2;
+ const int filter_count = 1;
+ const int dilation_factor = 2;
+ QuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(),
+ {TensorType_UINT8,
+ {image_batch_count, image_height, image_width, depth},
+ 0,
+ 255},
+ {TensorType_UINT8,
+ {depth, filter_size, filter_size, filter_count},
+ 0,
+ 255},
+ {TensorType_UINT8, {}, 0, 255}, Padding_SAME, dilation_factor);
+
+ // The image matrix is:
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ // | 1 | 1 | 1 |
+ m.SetInput({1, 1, 1, 1, 1, 1, 1, 1, 1});
+ // The filter matrix is:
+ // | 1 | 2 |
+ // | 3 | 4 |
+ m.SetFilter({1, 2, 3, 4});
+ // No bias for this test.
+ m.SetBias({0});
+ m.Invoke();
+
+ // Output:
+ // | 4 | 7 | 3 |
+ // | 6 |10 | 4 |
+ // | 2 | 3 | 1 |
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
+INSTANTIATE_TEST_CASE_P(
+ QuantizedDepthwiseConvolutionOpTest, QuantizedDepthwiseConvolutionOpTest,
+ ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index d2906632d7..e21dc5ced9 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <string.h>
#include <numeric>
#include <vector>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
index 94c91a6bd6..1e8caebd82 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index 04995d70dd..8c624b3208 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -90,6 +90,10 @@ TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
}
+TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
+ return EvalNumeric(context, node, [](float f) { return f * f; });
+}
+
TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
return EvalLogical(context, node, [](bool v) { return !v; });
}
@@ -129,6 +133,14 @@ TfLiteRegistration* Register_RSQRT() {
return &r;
}
+TfLiteRegistration* Register_SQUARE() {
+ static TfLiteRegistration r = {
+ /*init=*/nullptr, /*free=*/nullptr,
+ elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
+ elementwise::SquareEval};
+ return &r;
+}
+
TfLiteRegistration* Register_LOGICAL_NOT() {
static TfLiteRegistration r = {
/*init=*/nullptr, /*free=*/nullptr,
diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc
index b9d7d73c52..5dd89a0eae 100644
--- a/tensorflow/contrib/lite/kernels/elementwise_test.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc
@@ -92,6 +92,15 @@ TEST(ElementWise, Rsqrt) {
EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
}
+TEST(ElementWise, Square) {
+ ElementWiseOpFloatModel m(BuiltinOperator_SQUARE, {1, 1, 4, 1});
+ m.PopulateTensor<float>(m.input(), {1, 2, 0.5, -3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray(ArrayFloatNear({1, 4.0, 0.25, 9.0})));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1}));
+}
+
TEST(ElementWise, LogicalNot) {
ElementWiseOpBoolModel m(BuiltinOperator_LOGICAL_NOT, {1, 1, 4, 1});
m.PopulateTensor<bool>(m.input(), {true, false, true, false});
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 7a71fcc219..f6d2f76dbe 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -281,15 +281,23 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
- type::FullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<output_data_type>(output), GetTensorDims(output), \
- gemm_context)
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.input_offset = input_offset; \
+ op_params.weights_offset = filter_offset; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ gemm_context); \
+ }
if (kernel_type == kReference) {
switch (output->type) {
case kTfLiteUInt8:
@@ -349,15 +357,20 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError;
}
-#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
- type::ShuffledFullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<int16_t>(output), GetTensorDims(output), \
- GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context)
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::ShuffledFullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<int16_t>(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \
+ }
if (kernel_type == kReference) {
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
} else {
@@ -376,12 +389,17 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_FULLY_CONNECTED(type) \
- type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(filter), GetTensorDims(filter), \
- GetTensorData<float>(bias), GetTensorDims(bias), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.float_activation_min = output_activation_min; \
+ op_params.float_activation_max = output_activation_max; \
+ type::FullyConnected(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(filter), \
+ GetTensorData<float>(filter), GetTensorShape(bias), \
+ GetTensorData<float>(bias), GetTensorShape(output), \
+ GetTensorData<float>(output)); \
+ }
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
} else if (kernel_type == kPie) {
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index a6fd4ac2dd..afb5ec05df 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -43,6 +43,10 @@ cc_library(
"compatibility.h",
"types.h",
],
+ deps = [
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "@com_google_absl//absl/base:core_headers",
+ ],
)
config_setting(
@@ -259,6 +263,7 @@ cc_library(
deps = [
":round",
":types",
+ "//tensorflow/contrib/lite/kernels:op_macros",
],
)
@@ -290,7 +295,9 @@ cc_library(
"common.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
"reference/reference_ops.h",
+ "reference/softmax.h",
],
deps = [
":quantization_util",
@@ -299,6 +306,7 @@ cc_library(
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -319,8 +327,10 @@ cc_library(
"common.h",
"reference/depthwiseconv_float.h",
"reference/depthwiseconv_uint8.h",
+ "reference/fully_connected.h",
"reference/legacy_reference_ops.h",
"reference/reference_ops.h",
+ "reference/softmax.h",
],
deps = [
":quantization_util",
@@ -329,6 +339,7 @@ cc_library(
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:op_macros",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -458,9 +469,10 @@ cc_library(
],
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
- "//tensorflow/contrib/lite/kernels:activation_functor",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
+ "//tensorflow/contrib/lite/kernels:op_macros",
"@gemmlowp",
] + select({
":arm": [
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index 93fc6b6a76..b87cf2b60d 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -15,65 +15,65 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
-#include <cassert>
#include <cstdint>
-#include <cstdlib>
+
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
#ifndef TFLITE_DCHECK
-#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false)
+#define TFLITE_DCHECK(condition) (condition) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_EQ
-#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_NE
-#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GE
-#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_GT
-#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LE
-#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
#ifndef TFLITE_DCHECK_LT
-#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false)
+#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ASSERT_FALSE
#endif
// TODO(ahentz): Clean up: We should stick to the DCHECK versions.
#ifndef TFLITE_CHECK
-#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort()
+#define TFLITE_CHECK(condition) (condition) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_EQ
-#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_NE
-#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GE
-#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_GT
-#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LE
-#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : TFLITE_ABORT
#endif
#ifndef TFLITE_CHECK_LT
-#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort()
+#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : TFLITE_ABORT
#endif
// TODO(ahentz): Clean up.
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
index 844ee6a53d..41862a21a6 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/test_util.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -28,23 +29,21 @@ namespace tflite {
namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
-template <FusedActivationFunctionType Ac>
-void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
- const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+void TestOneDepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
std::vector<float> output_data(output_buffer_size);
std::vector<float> reference_output_data(output_buffer_size);
- reference_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- reference_output_data.data(), output_dims);
- optimized_ops::DepthwiseConv<Ac>(input_data, input_dims, filter_data,
- filter_dims, bias_data, bias_dims, stride,
- pad_width, pad_height, depth_multiplier,
- output_data.data(), output_dims);
+ reference_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
+
double sum_abs_diff = 0;
float max_abs_val = 0;
for (int i = 0; i < output_buffer_size; i++) {
@@ -59,27 +58,6 @@ void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
-void TestOneDepthwiseConv(FusedActivationFunctionType Ac,
- const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride, int pad_width, int pad_height,
- int depth_multiplier, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE) \
- if (AC_TYPE == Ac) { \
- TestOneDepthwiseConv<AC_TYPE>(input_data, input_dims, filter_data, \
- filter_dims, bias_data, bias_dims, stride, \
- pad_width, pad_height, depth_multiplier, \
- output_dims); \
- return; \
- }
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1)
- TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6)
-#undef TOCO_HANDLE_CASE
-}
-
// This function picks some random DepthwiseConv params, which may or may not
// be legal. If they're not legal, it returns false. If they're legal,
// it runs the DepthwiseConv test and returns true. This allows the caller
@@ -99,6 +77,16 @@ bool TryTestOneDepthwiseConv() {
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
const int output_depth = input_depth * depth_multiplier;
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ float output_activation_min, output_activation_max;
+ FusedActivationFunctionType ac =
+ RandomElement(std::vector<FusedActivationFunctionType>(
+ {FusedActivationFunctionType::kNone,
+ FusedActivationFunctionType::kRelu,
+ FusedActivationFunctionType::kRelu1,
+ FusedActivationFunctionType::kRelu6}));
+ GetActivationMinMax(ac, &output_activation_min, &output_activation_max);
// The optimized DepthwiseConv implementation currently uses a fixed-size
// accumulator buffer on the stack, with that size. This currently means
// that it does not support larger output depths. It CHECK's for it,
@@ -109,27 +97,23 @@ bool TryTestOneDepthwiseConv() {
if (output_depth > kMaxSupportedOutputDepth) {
return false;
}
- const auto ac = RandomElement(std::vector<FusedActivationFunctionType>(
- {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu,
- FusedActivationFunctionType::kRelu6,
- FusedActivationFunctionType::kRelu1}));
- Dims<4> input_dims_inference =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- Dims<4> output_dims_inference;
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
int pad_width, pad_height;
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
- if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
- &output_dims_inference, &pad_width, &pad_height)) {
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
return false;
}
- Dims<4> filter_dims_inference =
- MakeDimsForInference(output_depth, filter_width, filter_height, 1);
- Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int filter_buffer_size =
- RequiredBufferSizeForDims(filter_dims_inference);
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
std::vector<float> input_data(input_buffer_size);
std::vector<float> filter_data(filter_buffer_size);
std::vector<float> bias_data(output_depth);
@@ -140,10 +124,21 @@ bool TryTestOneDepthwiseConv() {
FillRandom(&input_data, -input_amplitude, input_amplitude);
FillRandom(&filter_data, -filter_amplitude, filter_amplitude);
FillRandom(&bias_data, -bias_amplitude, bias_amplitude);
- TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
- filter_data.data(), filter_dims_inference,
- bias_data.data(), bias_dims_inference, stride, pad_width,
- pad_height, depth_multiplier, output_dims_inference);
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ TestOneDepthwiseConv(op_params, input_shape_inference, input_data.data(),
+ filter_shape_inference, filter_data.data(),
+ bias_shape_inference, bias_data.data(),
+ output_shape_inference);
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
index 2c0fc8433e..9414e109c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -35,29 +35,40 @@ namespace {
// Runs the DepthwiseConv and compares against the reference implementation.
template <FusedActivationFunctionType Ac>
int TestOneDepthwiseConvWithGivenOutputShift(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
int output_shift, std::int32_t output_activation_min,
- std::int32_t output_activation_max, const Dims<4>& output_dims) {
- const int output_buffer_size = RequiredBufferSizeForDims(output_dims);
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+ const int output_buffer_size = output_shape.FlatSize();
std::vector<std::uint8_t> output_data(output_buffer_size);
std::vector<std::uint8_t> reference_output_data(output_buffer_size);
- reference_ops::DepthwiseConv<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max,
- reference_output_data.data(), output_dims);
- optimized_ops::DepthwiseConv<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier, output_shift,
- output_activation_min, output_activation_max, output_data.data(),
- output_dims);
+
+ tflite::DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride;
+ op_params.stride_height = stride;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = -output_shift;
+ reference_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ reference_output_data.data());
+ optimized_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data.data());
int saturated_min = 0;
int saturated_max = 0;
std::vector<int> diff(output_buffer_size);
@@ -106,25 +117,25 @@ int TestOneDepthwiseConvWithGivenOutputShift(
// vacuous. So we just bisect our way to reasonable output_shift values.
template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConvBisectOutputShift(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
int output_activation_bisect_start, int output_activation_bisect_end,
std::int32_t output_activation_min, std::int32_t output_activation_max,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end)
<< "Bisection failed ?!?!";
int output_shift_bisect_midpoint =
(output_activation_bisect_start + output_activation_bisect_end) / 2;
int bisect_result = TestOneDepthwiseConvWithGivenOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier,
output_shift_bisect_midpoint, output_activation_min,
- output_activation_max, output_dims);
+ output_activation_max, output_shape);
// At this point we know that the test succeeded (otherwise it would have
// aborted).
if (bisect_result == 0) {
@@ -147,47 +158,47 @@ void TestOneDepthwiseConvBisectOutputShift(
? output_activation_bisect_end
: output_shift_bisect_midpoint;
TestOneDepthwiseConvBisectOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier,
new_output_activation_bisect_start, new_output_activation_bisect_end,
- output_activation_min, output_activation_max, output_dims);
+ output_activation_min, output_activation_max, output_shape);
}
template <FusedActivationFunctionType Ac>
void TestOneDepthwiseConv(
- const std::uint8_t* input_data, const Dims<4>& input_dims,
+ const std::uint8_t* input_data, const RuntimeShape& input_shape,
std::int32_t input_offset, const std::uint8_t* filter_data,
- const Dims<4>& filter_dims, std::int32_t filter_offset,
- const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride,
+ const RuntimeShape& filter_shape, std::int32_t filter_offset,
+ const std::int32_t* bias_data, const RuntimeShape& bias_shape, int stride,
int pad_width, int pad_height, int depth_multiplier,
std::int32_t output_offset, std::int32_t output_multiplier,
std::int32_t output_activation_min, std::int32_t output_activation_max,
- const Dims<4>& output_dims) {
+ const RuntimeShape& output_shape) {
TestOneDepthwiseConvBisectOutputShift<Ac>(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height,
+ input_data, input_shape, input_offset, filter_data, filter_shape,
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height,
depth_multiplier, output_offset, output_multiplier, 0, 32,
- output_activation_min, output_activation_max, output_dims);
+ output_activation_min, output_activation_max, output_shape);
}
void TestOneDepthwiseConv(
FusedActivationFunctionType Ac, const std::uint8_t* input_data,
- const Dims<4>& input_dims, std::int32_t input_offset,
- const std::uint8_t* filter_data, const Dims<4>& filter_dims,
+ const RuntimeShape& input_shape, std::int32_t input_offset,
+ const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
std::int32_t filter_offset, const std::int32_t* bias_data,
- const Dims<4>& bias_dims, int stride, int pad_width, int pad_height,
+ const RuntimeShape& bias_shape, int stride, int pad_width, int pad_height,
int depth_multiplier, std::int32_t output_offset,
std::int32_t output_multiplier, std::int32_t output_activation_min,
- std::int32_t output_activation_max, const Dims<4>& output_dims) {
-#define TOCO_HANDLE_CASE(AC_TYPE) \
- if (AC_TYPE == Ac) { \
- TestOneDepthwiseConv<AC_TYPE>( \
- input_data, input_dims, input_offset, filter_data, filter_dims, \
- filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, \
- depth_multiplier, output_offset, output_multiplier, \
- output_activation_min, output_activation_max, output_dims); \
- return; \
+ std::int32_t output_activation_max, const RuntimeShape& output_shape) {
+#define TOCO_HANDLE_CASE(AC_TYPE) \
+ if (AC_TYPE == Ac) { \
+ TestOneDepthwiseConv<AC_TYPE>( \
+ input_data, input_shape, input_offset, filter_data, filter_shape, \
+ filter_offset, bias_data, bias_shape, stride, pad_width, pad_height, \
+ depth_multiplier, output_offset, output_multiplier, \
+ output_activation_min, output_activation_max, output_shape); \
+ return; \
}
TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone)
TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu)
@@ -199,6 +210,7 @@ void TestOneDepthwiseConv(
bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
int input_height, int filter_width, int filter_height,
int depth_multiplier, int stride,
+ int dilation_width_factor, int dilation_height_factor,
PaddingType padding_type) {
const int output_depth = input_depth * depth_multiplier;
// The optimized DepthwiseConv implementation currently uses a fixed-size
@@ -226,33 +238,33 @@ bool TryTestDepthwiseConv(int batch, int input_depth, int input_width,
const std::int32_t input_offset = UniformRandomInt(-256, 0);
const std::int32_t filter_offset = UniformRandomInt(-256, 0);
const std::int32_t output_offset = UniformRandomInt(-256, 0);
- Dims<4> input_dims_inference =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- Dims<4> output_dims_inference;
+ RuntimeShape input_shape_inference(
+ {batch, input_height, input_width, input_depth});
+ RuntimeShape output_shape_inference;
int pad_width, pad_height;
- if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width,
- filter_height, stride, padding_type,
- &output_dims_inference, &pad_width, &pad_height)) {
+ if (!ComputeConvSizes(input_shape_inference, output_depth, filter_width,
+ filter_height, stride, dilation_width_factor,
+ dilation_height_factor, padding_type,
+ &output_shape_inference, &pad_width, &pad_height)) {
return false;
}
- Dims<4> filter_dims_inference =
- MakeDimsForInference(output_depth, filter_width, filter_height, 1);
- Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1);
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int filter_buffer_size =
- RequiredBufferSizeForDims(filter_dims_inference);
+ RuntimeShape filter_shape_inference(
+ {1, filter_height, filter_width, output_depth});
+ RuntimeShape bias_shape_inference({1, 1, 1, output_depth});
+ const int input_buffer_size = input_shape_inference.FlatSize();
+ const int filter_buffer_size = filter_shape_inference.FlatSize();
std::vector<std::uint8_t> input_data(input_buffer_size);
std::vector<std::uint8_t> filter_data(filter_buffer_size);
std::vector<std::int32_t> bias_data(output_depth);
FillRandom(&input_data);
FillRandom(&filter_data);
FillRandom(&bias_data, -10000, 10000);
- TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference,
- input_offset, filter_data.data(), filter_dims_inference,
- filter_offset, bias_data.data(), bias_dims_inference,
+ TestOneDepthwiseConv(ac, input_data.data(), input_shape_inference,
+ input_offset, filter_data.data(), filter_shape_inference,
+ filter_offset, bias_data.data(), bias_shape_inference,
stride, pad_width, pad_height, depth_multiplier,
output_offset, output_multiplier, output_activation_min,
- output_activation_max, output_dims_inference);
+ output_activation_max, output_shape_inference);
return true;
}
@@ -274,12 +286,15 @@ bool TryTestOneDepthwiseConv() {
const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10);
const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50);
const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8);
+ const int dilation_width_factor = RandomElement(std::vector<int>({1, 2, 4}));
+ const int dilation_height_factor = RandomElement(std::vector<int>({1, 2, 4}));
const auto padding_type =
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
// Tests parameters for the 3x3 filter kernel.
@@ -292,6 +307,9 @@ bool TryTestOneDepthwiseConv3x3Filter() {
const int filter_height = 3;
const int depth_multiplier = 1;
const int stride = UniformRandomInt(1, 2);
+ // We don't support dilations in the 3x3 filter.
+ const int dilation_width_factor = 1;
+ const int dilation_height_factor = 1;
// Although the kernel supports only kValid padding, we test that kSame
// is using the correct code path.
const auto padding_type =
@@ -299,7 +317,8 @@ bool TryTestOneDepthwiseConv3x3Filter() {
return TryTestDepthwiseConv(batch, input_depth, input_width, input_height,
filter_width, filter_height, depth_multiplier,
- stride, padding_type);
+ stride, dilation_width_factor,
+ dilation_height_factor, padding_type);
}
void TestOneDepthwiseConv() {
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
index 4a90e7e640..40d42bbae9 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h
@@ -49,9 +49,18 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
filter_width != 1 || filter_height != 1;
if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height,
- pad_width, pad_height, filter_height, filter_width, 0,
- im2col_data, im2col_dims);
+ ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
+ DimsToShape(input_dims), input_data,
+ DimsToShape(im2col_dims), im2col_data);
+
gemm_input_data = im2col_data;
gemm_input_dims = &im2col_dims;
} else {
@@ -82,8 +91,8 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
stride_a, b, stride_b, 0.0f, c, stride_c);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
}
} // namespace cblas_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
index 7f6eea2d5d..114575a96a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h
@@ -761,7 +761,8 @@ struct FloatDepthwiseConvKernel<true, 4, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
+void FloatDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
const float* input_data, int pad_width,
int depth_multiplier, int filter_width,
const float* filter_data,
@@ -835,10 +836,10 @@ void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width,
// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized.
inline void FloatDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const float* input_data,
- int pad_width, int depth_multiplier, int filter_width,
- const float* filter_data, int out_x_buffer_start, int out_x_buffer_end,
- int output_depth, float* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const float* input_data, int pad_width, int depth_multiplier,
+ int filter_width, const float* filter_data, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth, float* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -860,6 +861,7 @@ inline void FloatDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -869,14 +871,17 @@ inline void FloatDepthwiseConvAccumRowGeneric(
const float* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
float* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const float* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -907,25 +912,37 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
static const int kAccBufferMaxSize = 2048;
float acc_buffer[kAccBufferMaxSize];
@@ -946,7 +963,8 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_height_factor == 1 && dilation_width_factor == 1) { \
row_accum_func = \
FloatDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -990,14 +1008,22 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
row_accum_func = FloatDepthwiseConvAccumRowGeneric;
}
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
float* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1013,14 +1039,13 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
- row_accum_func(stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
- pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2],
- out_x_buffer_start, out_x_buffer_end, output_depth,
- acc_buffer);
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
+ row_accum_func(
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
+ pad_width, depth_multiplier, filter_width,
+ filter_data + filter_y * filter_height_stride, out_x_buffer_start,
+ out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating. Now store to destination.
const int num_output_values = output_depth * num_output_pixels;
@@ -1067,6 +1092,51 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
@@ -1083,6 +1153,7 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 3fd00c8930..f892b8f661 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -24,6 +24,9 @@ limitations under the License.
namespace tflite {
namespace optimized_ops {
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
// Implementation of quantized DepthwiseConv
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
@@ -1466,11 +1469,14 @@ struct QuantizedDepthwiseConvKernel<false, 12, 1> {
// Accumulates the effect of one row of the filter, on a segment of one row
// of the output, accessing the corresponding one row of the input.
template <bool kAllowStrided, int kFixedInputDepth, int kFixedDepthMultiplier>
-void QuantizedDepthwiseConvAccumRow(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor,
+ int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset,
+ int pad_width, int depth_multiplier,
+ int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start,
+ int out_x_buffer_end, int output_depth,
+ int32* acc_buffer) {
#ifdef GEMMLOWP_PROFILING
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
#endif
@@ -1537,10 +1543,11 @@ void QuantizedDepthwiseConvAccumRow(
// generic fallback of DepthwiseConvAccumRow, portable, non-templatized.
inline void QuantizedDepthwiseConvAccumRowGeneric(
- int stride, int input_depth, int input_width, const uint8* input_data,
- int16 input_offset, int pad_width, int depth_multiplier, int filter_width,
- const uint8* filter_data, int16 filter_offset, int out_x_buffer_start,
- int out_x_buffer_end, int output_depth, int32* acc_buffer) {
+ int stride, int dilation_factor, int input_depth, int input_width,
+ const uint8* input_data, int16 input_offset, int pad_width,
+ int depth_multiplier, int filter_width, const uint8* filter_data,
+ int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end,
+ int output_depth, int32* acc_buffer) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)");
#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
@@ -1562,6 +1569,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
<< "* stride = " << stride << "\n"
<< "* input_depth = " << input_depth << "\n"
<< "* depth_multiplier = " << depth_multiplier << "\n"
+ << "* dilation_factor = " << dilation_factor << "\n"
<< "*\n"
<< "* Please do not hesitate to contact benoitjacob@ with this\n"
<< "* information.\n"
@@ -1571,14 +1579,17 @@ inline void QuantizedDepthwiseConvAccumRowGeneric(
const uint8* filter_base_ptr = filter_data;
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int out_x_loop_start = std::max(
- out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride);
- const int out_x_loop_end =
- std::min(out_x_buffer_end,
- (pad_width + input_width - filter_x + stride - 1) / stride);
+ out_x_buffer_start,
+ (pad_width - dilation_factor * filter_x + stride - 1) / stride);
+ const int out_x_loop_end = std::min(
+ out_x_buffer_end,
+ (pad_width + input_width - dilation_factor * filter_x + stride - 1) /
+ stride);
int32* acc_buffer_ptr =
acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth;
- const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x;
+ const int in_x_origin =
+ (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x;
const uint8* input_ptr = input_data + in_x_origin * input_depth;
const int input_ptr_increment = (stride - 1) * input_depth;
for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) {
@@ -1669,33 +1680,46 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth,
}
}
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit");
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
-
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
#ifdef USE_NEON
- const bool shift_left = (output_shift <= 0);
- const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1;
+ const bool shift_left = (output_shift > 0);
+ const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1;
#endif
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
@@ -1703,14 +1727,12 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Call kernel optimized for depthwise convolutions using 3x3 filters if
// parameters are supported.
if (Fast3x3FilterKernelSupported(
- input_dims, filter_dims, stride_width, stride_height, pad_width,
- pad_height, depth_multiplier, output_dims, output_shift)) {
- DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data,
- filter_dims, filter_offset, bias_data, bias_dims,
- stride_width, stride_height, pad_width, pad_height,
- depth_multiplier, output_offset, output_multiplier,
- output_shift, output_activation_min,
- output_activation_max, output_data, output_dims);
+ input_shape, filter_shape, stride_width, stride_height,
+ dilation_width_factor, dilation_height_factor, pad_width, pad_height,
+ depth_multiplier, output_shape, output_shift)) {
+ DepthwiseConv3x3Filter(params, input_shape, input_data, filter_shape,
+ filter_data, bias_shape, bias_data, output_shape,
+ output_data);
return;
}
#endif
@@ -1734,7 +1756,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
FIXED_DEPTH_MULTIPLIER) \
if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \
(input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \
- depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \
+ depth_multiplier == FIXED_DEPTH_MULTIPLIER && \
+ dilation_width_factor == 1 && dilation_height_factor == 1) { \
row_accum_func = \
QuantizedDepthwiseConvAccumRow<ALLOW_STRIDED, FIXED_INPUT_DEPTH, \
FIXED_DEPTH_MULTIPLIER>; \
@@ -1785,14 +1808,22 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
#undef TFMINI_USE_DEPTHWISECONV_KERNEL
+ const int input_height_stride = input_shape.Dims(3) * input_shape.Dims(2);
+ const int input_batch_stride = input_height_stride * input_shape.Dims(1);
+ const int filter_height_stride = filter_shape.Dims(3) * filter_shape.Dims(2);
+
// Now that we have determined row_accum_func, we can start work.
uint8* output_ptr = output_data;
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
const int in_y_origin = (out_y * stride_height) - pad_height;
- const int filter_y_start = std::max(0, -in_y_origin);
+ const int filter_y_start =
+ std::max(0, (-in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
const int filter_y_end =
- std::min(filter_height, input_height - in_y_origin);
+ std::min(filter_height,
+ (input_height - in_y_origin + dilation_height_factor - 1) /
+ dilation_height_factor);
for (int out_x_buffer_start = 0; out_x_buffer_start < output_width;
out_x_buffer_start += kOutputPixelsInAccBuffer) {
const int out_x_buffer_end = std::min(
@@ -1808,13 +1839,12 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Accumulation loop. Most of the time should be spent in here.
for (int filter_y = filter_y_start; filter_y < filter_y_end;
++filter_y) {
- const int in_y = in_y_origin + filter_y;
+ const int in_y = in_y_origin + dilation_height_factor * filter_y;
row_accum_func(
- stride_width, input_depth, input_width,
- input_data + in_y * input_dims.strides[2] +
- b * input_dims.strides[3],
+ stride_width, dilation_width_factor, input_depth, input_width,
+ input_data + in_y * input_height_stride + b * input_batch_stride,
input_offset, pad_width, depth_multiplier, filter_width,
- filter_data + filter_y * filter_dims.strides[2], filter_offset,
+ filter_data + filter_y * filter_height_stride, filter_offset,
out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer);
}
// Finished accumulating int32 values. Now need to convert them to
@@ -1845,7 +1875,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier);
}
for (int j = 0; j < 4; j++) {
- acc[j] = RoundingDivideByPOT(acc[j], output_shift);
+ acc[j] = RoundingDivideByPOT(acc[j], -output_shift);
}
} else {
// Fixed-point multiplication.
@@ -1889,8 +1919,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
acc0 = vqrdmulhq_n_s32(acc0, output_multiplier);
acc1 = vqrdmulhq_n_s32(acc1, output_multiplier);
// Rounding right shift.
- acc0 = RoundingDivideByPOT(acc0, output_shift);
- acc1 = RoundingDivideByPOT(acc1, output_shift);
+ acc0 = RoundingDivideByPOT(acc0, -output_shift);
+ acc1 = RoundingDivideByPOT(acc1, -output_shift);
} else {
// Fixed-point multiplication.
acc0 = vmulq_n_s32(acc0, multiplier_power_of_two);
@@ -1926,7 +1956,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
// Fixed-point multiplication.
acc = vqrdmulhq_n_s32(acc, output_multiplier);
// Rounding right shift.
- acc = RoundingDivideByPOT(acc, output_shift);
+ acc = RoundingDivideByPOT(acc, -output_shift);
} else {
// Fixed-point multiplication.
acc = vmulq_n_s32(acc, multiplier_power_of_two);
@@ -1953,7 +1983,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
for (; i < num_output_values; i++) {
int32 acc = acc_buffer[i];
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
@@ -1964,6 +1994,62 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
@@ -1987,6 +2073,7 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 0ce64f8c70..4809ddd02a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -49,7 +49,7 @@ struct DepthwiseConvParams {
int32 output_multiplier;
int32 output_activation_min;
int32 output_activation_max;
- int32 output_shift;
+ int32 output_right_shift;
int32 input_width;
int32 input_height;
int32 stride_width;
@@ -75,7 +75,7 @@ struct DepthwiseConvParams {
#define OFFSET_OUTPUT_MULTIPLIER 52
#define OFFSET_OUTPUT_ACTIVATION_MIN 56
#define OFFSET_OUTPUT_ACTIVATION_MAX 60
-#define OFFSET_OUTPUT_SHIFT 64
+#define OFFSET_OUTPUT_RIGHT_SHIFT 64
#define OFFSET_INPUT_WIDTH 68
#define OFFSET_INPUT_HEIGHT 72
#define OFFSET_STRIDE_WIDTH 76
@@ -105,8 +105,8 @@ static_assert(offsetof(DepthwiseConvParams, output_activation_min) ==
OFFSET_OUTPUT_ACTIVATION_MIN, "");
static_assert(offsetof(DepthwiseConvParams, output_activation_max) ==
OFFSET_OUTPUT_ACTIVATION_MAX, "");
-static_assert(offsetof(DepthwiseConvParams, output_shift) ==
- OFFSET_OUTPUT_SHIFT, "");
+static_assert(offsetof(DepthwiseConvParams, output_right_shift) ==
+ OFFSET_OUTPUT_RIGHT_SHIFT, "");
static_assert(offsetof(DepthwiseConvParams, input_width) ==
OFFSET_INPUT_WIDTH, "");
static_assert(offsetof(DepthwiseConvParams, input_height) ==
@@ -189,7 +189,7 @@ struct DepthwiseConvWindow<8, 1, 1> {
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n"
"ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w9\n"
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v29.4s, w2\n"
"ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"dup v30.4s, w4\n"
@@ -1166,7 +1166,7 @@ struct DepthwiseConvWindow<8, 2, 2> {
// values from time to time when there are not enough NEON registers.
// We use x9--x15 general purpose registers as they are caller-saved
// temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT
- "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n"
"cmp %w[output_window_height], #2\n"
"dup v28.8h, w0\n"
@@ -2216,7 +2216,7 @@ struct DepthwiseConvPartial<EdgeType::kCenter, 1, 1> {
"dup v27.4s, w10\n"
"ld1 {v0.8b}, [%[filter_ptr]], #8\n"
"cmp x11, #16\n"
- "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w9\n"
"ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w10, w10\n"
@@ -2355,7 +2355,7 @@ struct DepthwiseConvPartial<EdgeType::kCorner, 1, 1> {
"dup v26.8h, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w7\n"
- "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w6\n"
"ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w7, w7\n"
@@ -2532,7 +2532,7 @@ struct DepthwiseConvPartial<EdgeType::kHorizontal, 1, 1> {
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2739,7 +2739,7 @@ struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
"dup v26.8h, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n"
"dup v27.4s, w13\n"
- "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n"
+ "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_RIGHT_SHIFT) "]\n"
"dup v28.4s, w12\n"
"ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n"
"neg w13, w13\n"
@@ -2910,7 +2910,7 @@ struct DepthwiseConvPartial<EdgeType::kVertical, 1, 1> {
#undef OFFSET_OUTPUT_MULTIPLIER
#undef OFFSET_OUTPUT_ACTIVATION_MIN
#undef OFFSET_OUTPUT_ACTIVATION_MAX
-#undef OFFSET_OUTPUT_SHIFT
+#undef OFFSET_OUTPUT_RIGHT_SHIFT
#undef OFFSET_INPUT_WIDTH
#undef OFFSET_INPUT_HEIGHT
#undef OFFSET_OUTPUT_WIDTH
@@ -3175,16 +3175,18 @@ inline void DepthwiseConvHandlePadding(const uint8* input_data,
}
inline bool Fast3x3FilterKernelSupported(
- const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width,
- int32 stride_height, int32 pad_width, int32 pad_height,
- int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) {
- const int32 input_height = ArraySize(input_dims, 2);
- const int32 input_width = ArraySize(input_dims, 1);
- const int32 input_depth = ArraySize(input_dims, 0);
- const int32 filter_height = ArraySize(filter_dims, 2);
- const int32 filter_width = ArraySize(filter_dims, 1);
- const int32 output_height = ArraySize(output_dims, 2);
- const int32 output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& input_shape, const RuntimeShape& filter_shape,
+ int32 stride_width, int32 stride_height, int32 dilation_width_factor,
+ int32 dilation_height_factor, int32 pad_width, int32 pad_height,
+ int32 depth_multiplier, const RuntimeShape& output_shape,
+ int32 output_shift) {
+ const int32 input_height = input_shape.Dims(1);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 input_depth = input_shape.Dims(3);
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
+ const int32 output_height = output_shape.Dims(1);
+ const int32 output_width = output_shape.Dims(2);
bool supported =
filter_width == 3 && filter_height == 3 && depth_multiplier == 1 &&
@@ -3192,7 +3194,8 @@ inline bool Fast3x3FilterKernelSupported(
(stride_height == 1 || stride_height == 2) &&
(stride_width == stride_height) && (pad_width == 0 || pad_width == 1) &&
(pad_height == 0 || pad_height == 1) && (pad_width == pad_height) &&
- (input_depth % 8) == 0 && (output_shift > 0);
+ (input_depth % 8) == 0 && (output_shift <= 0) &&
+ dilation_width_factor == 1 && dilation_height_factor == 1;
if (!supported) {
return false;
@@ -3234,36 +3237,47 @@ inline bool Fast3x3FilterKernelSupported(
}
inline void DepthwiseConv3x3Filter(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width,
- int32 stride_height, int32 pad_width, int32 pad_height,
- int32 depth_multiplier, int32 output_offset, int32 output_multiplier,
- int32 output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ const DepthwiseParams& rt_params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__);
DepthwiseConvParams params;
- params.input_depth = ArraySize(input_dims, 0);
- params.input_width = ArraySize(input_dims, 1);
- params.input_height = ArraySize(input_dims, 2);
+
+ const int32 stride_width = rt_params.stride_width;
+ const int32 stride_height = rt_params.stride_height;
+ const int32 pad_width = rt_params.padding_values.width;
+ const int32 pad_height = rt_params.padding_values.height;
+ const int32 depth_multiplier = rt_params.depth_multiplier;
+ const int32 output_activation_min = rt_params.quantized_activation_min;
+ const int32 output_activation_max = rt_params.quantized_activation_max;
+ const int32 input_offset = rt_params.input_offset;
+ const int32 filter_offset = rt_params.weights_offset;
+ const int32 output_offset = rt_params.output_offset;
+ const int32 output_multiplier = rt_params.output_multiplier;
+ const int32 output_shift = rt_params.output_shift;
+
+ params.input_depth = input_shape.Dims(3);
+ params.input_width = input_shape.Dims(2);
+ params.input_height = input_shape.Dims(1);
params.input_row_size = params.input_depth * params.input_width;
params.input_offset = input_offset;
params.stride_width = stride_width;
params.stride_height = stride_height;
- params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- params.output_width = ArraySize(output_dims, 1);
- params.output_height = ArraySize(output_dims, 2);
+ params.output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ params.output_width = output_shape.Dims(2);
+ params.output_height = output_shape.Dims(1);
params.output_row_size = params.output_depth * params.output_width;
params.output_offset = output_offset;
params.filter_offset = filter_offset;
params.output_multiplier = output_multiplier;
- params.output_shift = output_shift;
+ params.output_right_shift = -output_shift;
params.output_activation_min = output_activation_min;
params.output_activation_max = output_activation_max;
- const int32 filter_height = ArraySize(filter_dims, 2);
- const int32 filter_width = ArraySize(filter_dims, 1);
+ const int32 filter_height = filter_shape.Dims(1);
+ const int32 filter_width = filter_shape.Dims(2);
params.filter_row_size = params.output_depth * filter_width;
// Algorithm assumes below constraints. It is optimized for depth
@@ -3279,7 +3293,7 @@ inline void DepthwiseConv3x3Filter(
TFLITE_DCHECK(pad_width == 0 || pad_width == 1);
TFLITE_DCHECK(pad_width == pad_height);
- const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
+ const int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
const int64_t input_batch_size = params.input_row_size * params.input_height;
const int64_t output_batch_size =
params.output_row_size * params.output_height;
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 5fb31889fe..b5d001cc9e 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -113,8 +113,8 @@ class EigenTensorConvFunctor {
filter_width * filter_height * input_depth;
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
dim_pair[0] = Eigen::IndexPair<Eigen::DenseIndex>(1, 0);
- EigenMatrix output(output_data, 1, filter_count);
- ConstEigenMatrix input(input_data, 1, k);
+ EigenMatrix output(output_data, input_batches, filter_count);
+ ConstEigenMatrix input(input_data, input_batches, k);
ConstEigenMatrix filter(filter_data, k, filter_count);
MatMulConvFunctor<Eigen::ThreadPoolDevice, T>()(device, output, input,
filter, dim_pair);
@@ -157,8 +157,8 @@ inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data,
output_width);
optimized_ops::AddBiasAndEvalActivationFunction(
- bias_data, bias_dims, output_data, output_dims, output_activation_min,
- output_activation_max);
+ output_activation_min, output_activation_max, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
}
} // namespace multithreaded_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 27418178fd..36c15dbc57 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -457,7 +457,7 @@ void NeonSymmetricQuantizeFloats(const float* values, const int size,
return;
}
*scaling_factor = range / kScale;
- const float scaling_factor_inv = 1.0f / *scaling_factor;
+ const float scaling_factor_inv = kScale / range;
const int postamble_start =
size - (size & (2 * kFloatWeightsPerNeonLane - 1));
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 2c8e8f90e3..0999738396 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -81,20 +81,16 @@ using reference_ops::Select;
using reference_ops::SpaceToBatchND;
using reference_ops::Split;
using reference_ops::StridedSlice;
+using reference_ops::TensorFlowSplit;
using reference_ops::Transpose;
// TODO(b/80247582) Remove this constant.
// This will be phased out as the shifts are revised with more thought. Use of a
// constant enables us to track progress on this work.
//
-// Used mainly to convert from old-style shifts (right) to new-style (left).
+// Used to convert from old-style shifts (right) to new-style (left).
static constexpr int kReverseShift = -1;
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
- return RuntimeShape(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
// Make a local VectorMap typedef allowing to map a float array
// as a Eigen vector expression. The std::conditional here is to
// construct the suitable Eigen type for the constness of the
@@ -188,6 +184,15 @@ ArrayMap<Scalar> MapAsArrayWithFirstDimAsRows(Scalar* data,
return ArrayMap<Scalar>(data, rows, cols);
}
+template <typename Scalar>
+ArrayMap<Scalar> MapAsArrayWithLastDimAsRows(Scalar* data,
+ const RuntimeShape& shape) {
+ const int dims_count = shape.DimensionsCount();
+ const int rows = shape.Dims(dims_count - 1);
+ const int cols = FlatSizeSkipDim(shape, dims_count - 1);
+ return ArrayMap<Scalar>(data, rows, cols);
+}
+
// Copied from tensorflow/core/framework/tensor_types.h
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
@@ -200,6 +205,8 @@ struct TTypes {
UnalignedConstMatrix;
};
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
// TODO(b/62193649): this function is only needed as long
// as we have the --variable_batch hack.
template <typename Scalar, int N>
@@ -212,6 +219,18 @@ MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
return MatrixMap<Scalar>(data, rows, cols);
}
+// TODO(b/62193649): this function is only needed as long
+// as we have the --variable_batch hack.
+template <typename Scalar>
+MatrixMap<Scalar> MapAsMatrixWithGivenNumberOfRows(Scalar* data,
+ const RuntimeShape& shape,
+ int rows) {
+ const int flatsize = shape.FlatSize();
+ TFLITE_DCHECK_EQ(flatsize % rows, 0);
+ const int cols = flatsize / rows;
+ return MatrixMap<Scalar>(data, rows, cols);
+}
+
// This is like the template-parameter version, except that the power-of-two is
// passed as a function parameter. The template version is to be preferred,
// since some target hardware optimizations depend on the range of the exponent.
@@ -260,16 +279,16 @@ inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) {
return true;
}
-inline void AddBiasAndEvalActivationFunction(const float* bias_data,
- const Dims<4>& bias_dims,
- float* array_data,
- const Dims<4>& array_dims,
- float output_activation_min,
- float output_activation_max) {
+inline void AddBiasAndEvalActivationFunction(float output_activation_min,
+ float output_activation_max,
+ const RuntimeShape& bias_shape,
+ const float* bias_data,
+ const RuntimeShape& array_shape,
+ float* array_data) {
#ifdef USE_NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
float* array_ptr = array_data;
float* array_end_ptr = array_ptr + array_size;
@@ -319,8 +338,8 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
}
#else // not NEON
gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction");
- const int bias_size = FlatSize(bias_dims);
- const int array_size = FlatSize(array_dims);
+ const int bias_size = bias_shape.FlatSize();
+ const int array_size = array_shape.FlatSize();
TFLITE_DCHECK_EQ((array_size % bias_size), 0);
for (int array_offset = 0; array_offset < array_size;
array_offset += bias_size) {
@@ -333,6 +352,19 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void AddBiasAndEvalActivationFunction(const float* bias_data,
+ const Dims<4>& bias_dims,
+ float* array_data,
+ const Dims<4>& array_dims,
+ float output_activation_min,
+ float output_activation_max) {
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(array_dims), array_data);
+}
+
// Note: This to be converted to RuntimeShapes along with Conv.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
@@ -380,21 +412,24 @@ inline void optimized_ops_preload_l1_keep(const uint8* ptr) {
// to a matrix*vector product. LSTM cells contain a fully-connected node;
// when quantized, this becomes a special type of GEMV operation where
// the output is 16bit-quantized, thus needs its own special path.
-inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data,
- const Dims<4>& weights_dims,
- uint8 weights_zero_point, const int32* bias_data,
- const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data,
- const Dims<4>& output_dims) {
+inline void GEMVForLstmCell(const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& weights_shape,
+ const uint8* weights_data, uint8 weights_zero_point,
+ const RuntimeShape& bias_shape,
+ const int32* bias_data, int32 accum_multiplier,
+ int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -567,18 +602,21 @@ inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims,
#ifdef GEMMLOWP_NEON
inline void GEMVForLstmCellWithSymmetricRange(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 accum_multiplier,
- int accum_shift, int16* output_data, const Dims<4>& output_dims) {
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& weights_shape, const uint8* weights_data,
+ const RuntimeShape& bias_shape, const int32* bias_data,
+ int32 accum_multiplier, int accum_shift, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("GEMVForLstmCellWithSymmetricRange");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
// This special fast path for quantized LSTM cells does not try to support
// odd sizes that we haven't encountered in any LSTM cell, that would
// require special code (that would go untested until any LSTM cell
@@ -854,14 +892,16 @@ inline void GEMVForLstmCellWithSymmetricRange(
}
#endif
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnected");
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+
// TODO(b/62193649): this convoluted shape computation (determining
// input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows)
// is because the current --variable_batch hack consists in overwriting the
@@ -870,18 +910,38 @@ inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
// When that is fixed, this should become:
// const auto input_matrix_map =
// MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- const int input_rows = ArraySize(weights_dims, 0);
+ const int dims_count = weights_shape.DimensionsCount();
+ const int input_rows = weights_shape.Dims(dims_count - 1);
const auto input_matrix_map =
- MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows);
+ MapAsMatrixWithGivenNumberOfRows(input_data, input_shape, input_rows);
const auto filter_matrix_map =
- MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims);
+ MapAsMatrixWithLastDimAsRows(weights_data, weights_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -899,22 +959,25 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims,
#ifdef USE_NEON
inline void FullyConnectedAsGEMV(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ int32 input_offset, const RuntimeShape& filter_shape,
+ const uint8* filter_data, int32 filter_offset,
+ const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset,
int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+ int32 output_activation_max, const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_dims, 0), 1);
- const int input_size = FlatSizeSkipDim(input_dims, 3);
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(output_shape, output_dim_count - 1), 1);
+ const int input_size = FlatSizeSkipDim(input_shape, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
static constexpr int kPeel = 4;
- const bool shift_left = (output_shift <= 0);
+ const bool shift_left = (output_shift > 0);
for (int k = 0; k < input_size; k += 64) {
optimized_ops_preload_l1_stream(input_data + k);
}
@@ -1027,7 +1090,7 @@ inline void FullyConnectedAsGEMV(
bias_ptr += 4;
reduced = vaddq_s32(reduced, bias_vec);
if (shift_left) {
- const int32 multiplier_power_of_two = 1 << -output_shift;
+ const int32 multiplier_power_of_two = 1 << output_shift;
reduced = vmulq_n_s32(reduced, multiplier_power_of_two);
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
} else {
@@ -1035,7 +1098,7 @@ inline void FullyConnectedAsGEMV(
reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
// Rounding-shift-right.
using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, output_shift);
+ reduced = RoundingDivideByPOT(reduced, -output_shift);
}
// Add the output offset.
const int32x4_t output_offset_vec = vdupq_n_s32(output_offset);
@@ -1083,42 +1146,47 @@ struct GemmlowpOutputPipeline {
}
};
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
#ifdef USE_NEON
- const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0);
+ const int output_size = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
if (batches == 1 && !(output_size % 4)) {
return FullyConnectedAsGEMV(
- input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset, output_multiplier,
- output_shift, output_activation_min, output_activation_max, output_data,
- output_dims);
+ input_shape, input_data, input_offset, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data, output_offset, output_multiplier,
+ output_shift, output_activation_min, output_activation_max,
+ output_shape, output_data);
}
#endif // USE_NEON
- const int filter_rows = filter_dims.sizes[1];
- const int filter_cols = filter_dims.sizes[0];
- TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1);
- const int output_rows = output_dims.sizes[0];
+ const int filter_rows = filter_shape.Dims(filter_dim_count - 2);
+ const int filter_cols = filter_shape.Dims(filter_dim_count - 1);
+ TFLITE_DCHECK_EQ(filter_shape.FlatSize(), filter_rows * filter_cols);
+ const int output_rows = output_shape.Dims(output_dim_count - 1);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, output_rows, filter_cols, filter_cols);
@@ -1127,7 +1195,7 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, batches, output_rows);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -1135,30 +1203,66 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
inline void FullyConnected(
- const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
- const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
- int32 output_multiplier, int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data_int32, const RuntimeShape& output_shape,
+ int16* output_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("FullyConnected/Uint8Int16");
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
(void)gemm_context; // only used in properly optimized code.
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
TFLITE_DCHECK_EQ(output_offset, 0);
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
// Implementation of the fully connected node suited to the inside of an LSTM
// cell. The operands are 8-bit integers, the accumulators are internally
@@ -1169,17 +1273,17 @@ inline void FullyConnected(
if (batches == 1 && input_offset == -128 && output_activation_min == -32768 &&
output_activation_max == 32767) {
if (filter_offset == -128 && !(output_depth % 4) && !(accum_depth % 64)) {
- GEMVForLstmCellWithSymmetricRange(input_data, input_dims, filter_data,
- filter_dims, bias_data_int32, bias_dims,
- output_multiplier, -output_shift,
- output_data, output_dims);
+ GEMVForLstmCellWithSymmetricRange(
+ input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data_int32, output_multiplier, output_shift, output_shape,
+ output_data);
return;
}
if (!(output_depth % 4) && !(accum_depth % 8)) {
- GEMVForLstmCell(input_data, input_dims, filter_data, filter_dims,
- filter_offset, bias_data_int32, bias_dims,
- output_multiplier, -output_shift, output_data,
- output_dims);
+ GEMVForLstmCell(input_shape, input_data, filter_shape, filter_data,
+ filter_offset, bias_shape, bias_data_int32,
+ output_multiplier, output_shift, output_shape,
+ output_data);
return;
}
}
@@ -1199,7 +1303,7 @@ inline void FullyConnected(
scale_stage.result_offset_after_shift = 0;
scale_stage.result_fixedpoint_multiplier = output_multiplier;
// Note that this shift is negated wrt ordinary FC.
- scale_stage.result_exponent = -output_shift;
+ scale_stage.result_exponent = output_shift;
gemmlowp::OutputStageClamp clamp_stage;
clamp_stage.min = output_activation_min;
clamp_stage.max = output_activation_max;
@@ -1213,6 +1317,32 @@ inline void FullyConnected(
input_offset, output_pipeline);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims, int32 input_offset,
+ const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset,
+ int32 output_multiplier, int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data, const Dims<4>& output_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
@@ -1248,8 +1378,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
#if defined USE_NEON
const int8* shuffled_weights_ptr = shuffled_weights_data;
if (batches == 1) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
// Accumulation loop.
int32x4_t row_accum0 = vdupq_n_s32(0);
@@ -1315,8 +1445,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
vst1_s16(output_data + c, res16);
}
} else if (batches == 4) {
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ const int right_shift = output_shift > 0 ? 0 : -output_shift;
+ const int left_shift = output_shift > 0 ? output_shift : 0;
for (int c = 0; c < output_depth; c += 4) {
const int8* shuffled_input_ptr =
reinterpret_cast<const int8*>(shuffled_input_workspace_data);
@@ -1447,8 +1577,8 @@ inline void ShuffledFullyConnectedWorkerImpl(
// (16-bit, typically 3 integer bits) fixed-point format. The quantized
// multiplier and shift here have been pre-computed offline
// (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1499,7 +1629,7 @@ inline void ShuffledFullyConnectedWorkerImpl(
// quantized multiplier and shift here have been pre-computed offline
// (e.g. by toco).
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
// Saturate, cast to int16, and store to output array.
acc = std::max(acc, -32768);
acc = std::min(acc, 32767);
@@ -1555,26 +1685,34 @@ struct ShuffledFullyConnectedWorkerTask : gemmlowp::Task {
};
inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("ShuffledFullyConnected/8bit");
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
(void)gemm_context; // only used in optimized code.
TFLITE_DCHECK_EQ(output_activation_min, -32768);
TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
// but the current --variable_batch hack consists in overwriting the 3rd
// dimension with the runtime batch size, as we don't keep track for each
// array of which dimension is the batch dimension in it.
- const int batches = FlatSizeSkipDim(output_dims, 0);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
TFLITE_DCHECK((accum_depth % 16) == 0);
TFLITE_DCHECK((output_depth % 4) == 0);
// Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
@@ -1671,13 +1809,40 @@ inline void ShuffledFullyConnected(
gemm_context->workers_pool()->Execute(tasks);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
template <typename T>
-inline void ExtractPatchIntoBufferColumn(
- const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
- int stride_width, int stride_height, int pad_width, int pad_height,
- int in_width, int in_height, int in_depth, int single_buffer_length,
- int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) {
+inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
+ int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height,
+ int pad_width, int pad_height,
+ int in_width, int in_height,
+ int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data,
+ T* conv_buffer_data, uint8 zero_byte) {
gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn");
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
// This chunk of code reshapes all the inputs corresponding to
// output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
const int kwidth_times_indepth = kwidth * in_depth;
@@ -1699,7 +1864,7 @@ inline void ExtractPatchIntoBufferColumn(
const int output_row_offset = (buffer_id * single_buffer_length);
int out_offset =
output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
- int in_offset = Offset(input_dims, 0, iw_start, ih_start, b);
+ int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
// Express all of the calculations as padding around the input patch.
const int top_padding = h_offset;
@@ -1713,7 +1878,7 @@ inline void ExtractPatchIntoBufferColumn(
// patch that are off the edge of the input image.
if (top_padding > 0) {
const int top_row_elements = (top_padding * kwidth * in_depth);
- memset(conv_buffer_data + output_row_offset, byte_zero,
+ memset(conv_buffer_data + output_row_offset, zero_byte,
(top_row_elements * sizeof(T)));
}
@@ -1730,14 +1895,14 @@ inline void ExtractPatchIntoBufferColumn(
for (int ih = ih_start; ih < ih_end; ++ih) {
if (left_padding > 0) {
const int left_start = (out_offset - (left_padding * in_depth));
- memset(conv_buffer_data + left_start, byte_zero,
+ memset(conv_buffer_data + left_start, zero_byte,
(left_padding * in_depth * sizeof(T)));
}
memcpy(conv_buffer_data + out_offset, in_data + in_offset,
single_row_num * sizeof(T));
if (right_padding > 0) {
const int right_start = (out_offset + single_row_num);
- memset(conv_buffer_data + right_start, byte_zero,
+ memset(conv_buffer_data + right_start, zero_byte,
(right_padding * in_depth * sizeof(T)));
}
out_offset += kwidth_times_indepth;
@@ -1752,61 +1917,64 @@ inline void ExtractPatchIntoBufferColumn(
const int bottom_start =
output_row_offset +
((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
- memset(conv_buffer_data + bottom_start, byte_zero,
+ memset(conv_buffer_data + bottom_start, zero_byte,
(bottom_row_elements * sizeof(T)));
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 byte_zero,
- T* im2col_data) {
+inline void ExtractPatchIntoBufferColumn(
+ const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth,
+ int stride_width, int stride_height, int pad_width, int pad_height,
+ int in_width, int in_height, int in_depth, int single_buffer_length,
+ int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) {
+ ExtractPatchIntoBufferColumn(
+ DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width,
+ stride_height, pad_width, pad_height, in_width, in_height, in_depth,
+ single_buffer_length, buffer_id, in_data, conv_buffer_data, zero_byte);
+}
+
+template <typename T>
+void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
// For dilated convolution, the input pixels are not contiguous therefore we
// can't use the same opitimizations as Im2Col(). Though note this code would
// work fine for the non-dilated case too (though likely a bit slower).
gemmlowp::ScopedProfilingLabel label("DilatedIm2col");
TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 3);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 0);
// Construct the MxN sized im2col matrix.
// The rows M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Loop through the output rows (B x H x W)
for (int batch = 0; batch < batches; ++batch) {
@@ -1814,7 +1982,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
for (int out_x = 0; out_x < output_width; ++out_x) {
// Each im2col row is an output pixel. Arrange the input data in this
// row in an order we can conveniently multiply with the filter data.
- int row_offset = Offset(row_dims, out_x, out_y, batch, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
const int in_x_origin = (out_x * stride_width) - pad_width;
const int in_y_origin = (out_y * stride_height) - pad_height;
// Loop through all the pixels of the filter (Kh x Kw)
@@ -1825,25 +1993,25 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
// Loop through all the filter pixels in this row.
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
const int in_x = in_x_origin + dilation_width_factor * filter_x;
- int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims, col_offset, row_offset, 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
if ((in_x >= 0) && (in_x < input_width)) {
// Filter pixel is within the input, copy the input data.
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
memcpy(dst, src, input_depth * sizeof(T));
} else {
// Filter pixel is outside the input, zero it out.
- memset(dst, byte_zero, input_depth * sizeof(T));
+ memset(dst, zero_byte, input_depth * sizeof(T));
}
}
} else {
// Filter row is outside the input, zero out the entire filter row.
- int col_offset = Offset(col_dims, 0, 0, filter_y, 0);
- T* dst =
- im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0);
- memset(dst, byte_zero, filter_width * input_depth * sizeof(T));
+ int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
+ T* dst = im2col_data +
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
+ memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
}
}
}
@@ -1851,21 +2019,49 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T>
-void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
- int stride_height, int pad_width, int pad_height, int kheight,
- int kwidth, uint8 byte_zero, T* output_data,
- const Dims<4>& output_dims) {
+void DilatedIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+
+ DilatedIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+template <typename T>
+void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Im2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
int buffer_id = 0;
// Loop over the output nodes.
@@ -1873,93 +2069,154 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
ExtractPatchIntoBufferColumn(
- input_dims, w, h, b, kheight, kwidth, stride_width, stride_height,
+ input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
pad_width, pad_height, input_width, input_height, input_depth,
- output_depth, buffer_id, input_data, output_data, byte_zero);
+ output_depth, buffer_id, input_data, output_data, zero_byte);
++buffer_id;
}
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height, int kheight,
+ int kwidth, uint8 zero_byte, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+
+ Im2col(op_params, kheight, kwidth, zero_byte, DimsToShape(input_dims),
+ input_data, DimsToShape(output_dims), output_data);
+}
+
// legacy, for compatibility with old checked-in code
template <typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
(void)im2col_data;
- (void)im2col_dims;
+ (void)im2col_shape;
gemmlowp::ScopedProfilingLabel label("Conv");
// NB: static_cast<float>(0x00000000h) == 0.0f
const uint8 float_zero_byte = 0x00;
const float* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
filter_width != 1 || filter_height != 1;
if (need_dilated_im2col) {
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, float_zero_byte,
- im2col_data);
+ DilatedIm2col(params, float_zero_byte, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, float_zero_byte,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, float_zero_byte, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
// TODO(aselle): We need to make sure to not send im2col if it is not
// needed.
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims);
+ MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
-inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
- const int8_t* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* scaling_factors_ptr,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- int8_t* im2col_data, const Dims<4>& im2col_dims) {
- const int batch_size = input_dims.sizes[3];
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
+ const RuntimeShape& input_shape,
+ const int8_t* input_data,
+ const RuntimeShape& filter_shape,
+ const int8_t* filter_data,
+ const RuntimeShape& bias_shape, const float* bias_data,
+ const RuntimeShape& output_shape, float* output_data,
+ const RuntimeShape& im2col_shape, int8_t* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batch_size = input_shape.Dims(0);
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const int8_t* gemm_input_data = nullptr;
int num_input;
@@ -1970,25 +2227,22 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
TFLITE_DCHECK(im2col_data);
// symmetric quantization assumes zero point of 0.
const int input_zero_point = 0;
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
- im2col_dims.sizes[2] * im2col_dims.sizes[3];
+ num_input = im2col_shape.FlatSize();
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- num_input = input_dims.sizes[0] * input_dims.sizes[1] *
- input_dims.sizes[2] * input_dims.sizes[3];
+ num_input = input_shape.FlatSize();
}
// Flatten 4D matrices into 2D matrices for matrix multiplication.
// Flatten so that each filter has its own row.
- const int filter_rows = filter_dims.sizes[3];
- const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+ const int filter_rows = filter_shape.Dims(0);
+ const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
// In MatrixBatchVectorMultiplyAccumulate, each output value is the
// dot product of one row of the first matrix with one row of the second
@@ -1998,15 +2252,11 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
const int gemm_input_cols = filter_cols;
const int gemm_input_rows = num_input / gemm_input_cols;
- const int output_cols = output_dims.sizes[0];
- const int output_rows =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ const int output_cols = output_shape.Dims(3);
+ const int output_rows = FlatSizeSkipDim(output_shape, 3);
TFLITE_DCHECK_EQ(output_cols, filter_rows);
TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_cols);
// MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
// input matrix has its own scale factor. This code duplicates the scale
@@ -2023,11 +2273,39 @@ inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
/*result_stride=*/1);
- AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
- output_dims, output_activation_min,
- output_activation_max);
+ AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
+ bias_shape, bias_data, output_shape,
+ output_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ HybridConv(op_params, scaling_factors_ptr, DimsToShape(input_dims),
+ input_data, DimsToShape(filter_dims), filter_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -2045,6 +2323,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2061,6 +2340,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -2074,27 +2354,32 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
output_dims, im2col_data, im2col_dims);
}
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label("Conv/8bit");
-
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
const uint8* gemm_input_data = nullptr;
- const Dims<4>* gemm_input_dims = nullptr;
- const int filter_width = ArraySize(filter_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
+ const RuntimeShape* gemm_input_shape = nullptr;
+ const int filter_width = filter_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
const bool need_dilated_im2col =
dilation_width_factor != 1 || dilation_height_factor != 1;
const bool need_im2col = stride_width != 1 || stride_height != 1 ||
@@ -2104,53 +2389,47 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- DilatedIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, dilation_width_factor, dilation_height_factor,
- pad_width, pad_height, output_dims, input_zero_point,
- im2col_data);
+ DilatedIm2col(params, input_zero_point, input_shape, input_data,
+ filter_shape, output_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else if (need_im2col) {
TFLITE_DCHECK(im2col_data);
const int input_zero_point = -input_offset;
TFLITE_DCHECK_GE(input_zero_point, 0);
TFLITE_DCHECK_LE(input_zero_point, 255);
- Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
- pad_height, filter_height, filter_width, input_zero_point,
- im2col_data, im2col_dims);
+ Im2col(params, filter_height, filter_width, input_zero_point, input_shape,
+ input_data, im2col_shape, im2col_data);
gemm_input_data = im2col_data;
- gemm_input_dims = &im2col_dims;
+ gemm_input_shape = &im2col_shape;
} else {
TFLITE_DCHECK(!im2col_data);
gemm_input_data = input_data;
- gemm_input_dims = &input_dims;
+ gemm_input_shape = &input_shape;
}
- const int gemm_input_rows = gemm_input_dims->sizes[0];
+ const int gemm_input_rows = gemm_input_shape->Dims(3);
// Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784).
// The root cause has not yet been identified though. Same applies below for
// the other calls commented out. This is a partial rollback of cl/196819423.
- // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0);
- const int gemm_input_cols = gemm_input_dims->sizes[1] *
- gemm_input_dims->sizes[2] *
- gemm_input_dims->sizes[3];
- const int filter_rows = filter_dims.sizes[3];
+ // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_shape, 3);
+ const int gemm_input_cols = gemm_input_shape->Dims(0) *
+ gemm_input_shape->Dims(1) *
+ gemm_input_shape->Dims(2);
+ const int filter_rows = filter_shape.Dims(0);
// See b/79927784.
- // const int filter_cols = FlatSizeSkipDim(filter_dims, 3);
+ // const int filter_cols = FlatSizeSkipDim(filter_shape, 0);
const int filter_cols =
- filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
- const int output_rows = output_dims.sizes[0];
+ filter_shape.Dims(1) * filter_shape.Dims(2) * filter_shape.Dims(3);
+ const int output_rows = output_shape.Dims(3);
// See b/79927784.
- // const int output_cols = FlatSizeSkipDim(output_dims, 0);
+ // const int output_cols = FlatSizeSkipDim(output_shape, 3);
const int output_cols =
- output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ output_shape.Dims(0) * output_shape.Dims(1) * output_shape.Dims(2);
TFLITE_DCHECK_EQ(output_rows, filter_rows);
TFLITE_DCHECK_EQ(output_cols, gemm_input_cols);
TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows);
- TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
- TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::RowMajor> filter_matrix(
filter_data, filter_rows, filter_cols);
gemmlowp::MatrixMap<const uint8, gemmlowp::MapOrder::ColMajor> input_matrix(
@@ -2158,7 +2437,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
gemmlowp::MatrixMap<uint8, gemmlowp::MapOrder::ColMajor> output_matrix(
output_data, output_rows, output_cols);
const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp(
- bias_data, output_rows, output_offset, output_multiplier, -output_shift,
+ bias_data, output_rows, output_offset, output_multiplier, output_shift,
output_activation_min, output_activation_max);
gemmlowp::GemmWithOutputPipeline<uint8, uint8,
gemmlowp::L8R8WithLhsNonzeroBitDepthParams>(
@@ -2166,6 +2445,44 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
input_offset, output_pipeline);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
@@ -2184,6 +2501,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2213,6 +2531,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -2236,13 +2555,14 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
int pad_width, int pad_height, int kheight, int kwidth,
- uint8 byte_zero, T* output_data, const Dims<4>& output_dims) {
+ uint8 zero_byte, T* output_data, const Dims<4>& output_dims) {
Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight,
- kwidth, byte_zero, output_data, output_dims);
+ kwidth, zero_byte, output_data, output_dims);
}
// legacy, for compatibility with old checked-in code
@@ -2266,6 +2586,7 @@ void ConvAsGemm(const float* input_data, const Dims<4>& input_dims,
output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
@@ -2320,9 +2641,9 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -2361,9 +2682,9 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_depth = output_shape.Dims(3);
@@ -2472,6 +2793,7 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
@@ -3191,7 +3513,7 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -3316,62 +3638,96 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
}
}
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
gemmlowp::ScopedProfilingLabel label("LstmCell");
- MatchingArraySize( // batches
- input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims,
- 3, output_activ_dims, 3);
- MatchingArraySize( // height
- input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims,
- 2, output_activ_dims, 2);
- MatchingArraySize( // width
- input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims,
- 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ MatchingDim( // batches
+ input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
+ MatchingDim( // height
+ input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
+ MatchingDim( // width
+ input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Map raw arrays to Eigen arrays so we can use Eigen's optimized array
// operations.
ArrayMap<float> activ_temp_map =
- MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims);
+ MapAsArrayWithLastDimAsRows(activ_temp_data, activ_temp_shape);
auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth,
activ_temp_map.cols());
auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth,
@@ -3381,11 +3737,11 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth,
activ_temp_map.cols());
ArrayMap<const float> prev_state_map =
- MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims);
+ MapAsArrayWithLastDimAsRows(prev_state_data, prev_state_shape);
ArrayMap<float> output_state_map =
- MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims);
+ MapAsArrayWithLastDimAsRows(output_state_data, output_state_shape);
ArrayMap<float> output_activ_map =
- MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims);
+ MapAsArrayWithLastDimAsRows(output_activ_data, output_activ_shape);
// Combined memory state and final output calculation
gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput");
@@ -3399,56 +3755,120 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
output_state_map.tanh();
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
// Quantized LSTM cell. Currently just a copy of the reference impl in
// reference_ops.h. See the big function comment there, not replicating it
// here.
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label(
"LstmCell/quantized (8bit external, 16bit internal)");
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -3458,10 +3878,10 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
bool gemm_already_performed = false;
#ifdef GEMMLOWP_NEON
if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) {
- GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims,
- weights_data_uint8, weights_dims, weights_zero_point,
- bias_data_int32, bias_dims, accum_multiplier, accum_shift,
- activ_temp_data_int16, activ_temp_dims);
+ GEMVForLstmCell(concat_temp_shape, concat_temp_data_uint8, weights_shape,
+ weights_data_uint8, weights_zero_point, bias_shape,
+ bias_data_int32, accum_multiplier, accum_shift,
+ activ_temp_shape, activ_temp_data_int16);
gemm_already_performed = true;
}
#endif
@@ -3650,28 +4070,35 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
- int outputs_count, Scalar* const* output_data,
- const Dims<4>* const* output_dims) {
- gemmlowp::ScopedProfilingLabel label("TensorFlowSplit");
- TFLITE_DCHECK_GE(outputs_count, 1);
- for (int i = 0; i < outputs_count; i++) {
- MatchingFlatSizeSkipDim(*output_dims[i], 0, input_dims);
- }
- const int outer_size = FlatSizeSkipDim(input_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- // For now we don't have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- const Scalar* input_ptr = input_data;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < outputs_count; ++i) {
- memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr,
- output_dims[i]->sizes[0] * sizeof(Scalar));
- input_ptr += output_dims[i]->sizes[0];
- }
- }
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
}
inline int NodeOffset(int b, int h, int w, int height, int width) {
@@ -4113,9 +4540,9 @@ inline void LocalResponseNormalization(
}
}
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Softmax");
MatchingFlatSize(input_shape, output_shape);
@@ -4123,7 +4550,8 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Compute the exponential first, removing the max coefficient for numerical
// stability.
- out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+ out_mat =
+ (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
// We are separating out the exp function so that exp can be vectorized.
out_mat = out_mat.array().exp();
// Normalize to get the activations.
@@ -4132,10 +4560,22 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
out_mat.array().rowwise() *= scale;
}
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4341,10 +4781,24 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
@@ -4377,6 +4831,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -4491,12 +4954,15 @@ log_x_for_x_greater_than_or_equal_to_1(
}
// Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4554,7 +5020,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff = static_cast<int32>(block_input_data[c]) - max_in_row;
@@ -4578,6 +5044,22 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
@@ -4587,11 +5069,23 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
@@ -4724,7 +5218,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -4784,10 +5293,22 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy version.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy version.
inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
int16* output_data, const RuntimeShape& output_shape) {
- Logistic(input_shape, input_data, output_shape, output_data);
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
@@ -4798,12 +5319,24 @@ inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
output_map.array() = input_map.array().tanh();
}
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
// Note that this is almost the exact same code as in Logistic().
gemmlowp::ScopedProfilingLabel label("Tanh");
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
@@ -4945,10 +5478,25 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -5045,6 +5593,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -5442,9 +6000,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5491,9 +6049,9 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -5552,9 +6110,9 @@ inline void BatchToSpaceND(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_width = output_shape.Dims(2);
@@ -5638,8 +6196,10 @@ inline void PadImpl(const tflite::PadParams& op_params,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
gemmlowp::ScopedProfilingLabel label("Pad");
- RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
- RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
@@ -5771,7 +6331,7 @@ inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Slice");
- RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -5820,6 +6380,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -5831,59 +6401,56 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().max(max_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
template <typename T>
-void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
- const Dims<4>& filter_dims, int stride_width,
- int stride_height, int pad_width, int pad_height,
- const Dims<4>& output_dims, uint8 zero_byte,
- T* im2col_data) {
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
+template <typename T>
+void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& filter_shape,
+ const RuntimeShape& output_shape, T* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeIm2col");
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
TFLITE_DCHECK(im2col_data);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- MatchingArraySize(output_dims, 0, filter_dims, 0); // output_depth
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 0);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ MatchingDim(output_shape, 3, filter_shape, 3); // output_depth
// Construct the MxN sized im2col matrix.
// The rows M, are sub-ordered B x H x W
- Dims<4> row_dims;
- row_dims.sizes[0] = output_width;
- row_dims.sizes[1] = output_height;
- row_dims.sizes[2] = batches;
- row_dims.sizes[3] = 1;
- ComputeStrides(&row_dims);
-
+ const RuntimeShape row_shape({1, batches, output_height, output_width});
// The columns, N, are sub-ordered Kh x Kw x Din
- Dims<4> col_dims;
- col_dims.sizes[0] = input_depth;
- col_dims.sizes[1] = filter_width;
- col_dims.sizes[2] = filter_height;
- col_dims.sizes[3] = 1;
- ComputeStrides(&col_dims);
-
+ const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
// Use dimensions M and N to construct dims for indexing directly into im2col
- Dims<4> im2col_dims;
- im2col_dims.sizes[0] = FlatSize(col_dims);
- im2col_dims.sizes[1] = FlatSize(row_dims);
- im2col_dims.sizes[2] = 1;
- im2col_dims.sizes[3] = 1;
- ComputeStrides(&im2col_dims);
+ const RuntimeShape im2col_shape(
+ {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
// Build the im2col matrix by looping through all the input pixels,
// computing their influence on the output, rather than looping through all
// the output pixels. We therefore must initialize the im2col array to zero.
// This is potentially inefficient because we subsequently overwrite bytes
// set here. However, in practice memset is very fast and costs negligible.
- memset(im2col_data, zero_byte, FlatSize(im2col_dims) * sizeof(T));
+ memset(im2col_data, zero_byte, im2col_shape.FlatSize() * sizeof(T));
// Loop through the output batches
for (int batch = 0; batch < batches; ++batch) {
@@ -5903,11 +6470,11 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
if ((out_x >= 0) && (out_x < output_width)) {
// Copy the input elements of this pixel
T const* src =
- input_data + Offset(input_dims, 0, in_x, in_y, batch);
+ input_data + Offset(input_shape, batch, in_y, in_x, 0);
+ int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
+ int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
T* dst = im2col_data +
- Offset(im2col_dims,
- Offset(col_dims, 0, filter_x, filter_y, 0),
- Offset(row_dims, out_x, out_y, batch, 0), 0, 0);
+ Offset(im2col_shape, 0, 0, row_offset, col_offset);
memcpy(dst, src, input_depth * sizeof(T));
}
}
@@ -5918,31 +6485,71 @@ void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
}
}
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* im2col_data,
- const Dims<4>& im2col_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
+ const Dims<4>& filter_dims, int stride_width,
+ int stride_height, int pad_width, int pad_height,
+ const Dims<4>& output_dims, uint8 zero_byte,
+ T* im2col_data) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeIm2col(op_params, zero_byte, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), DimsToShape(output_dims),
+ im2col_data);
+}
+
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
gemmlowp::ScopedProfilingLabel label("TransposeConv");
// Note we could use transposed weights with forward conv for unstrided
// cases. But we are already getting good performance with this code as-is.
TFLITE_DCHECK(im2col_data);
- TransposeIm2col(input_data, input_dims, filter_dims, stride_width,
- stride_height, pad_width, pad_height, output_dims, 0,
- im2col_data);
+ TransposeIm2col(params, 0, input_shape, input_data, filter_shape,
+ output_shape, im2col_data);
const auto im2col_matrix_map =
- MapAsMatrixWithFirstDimAsRows(im2col_data, im2col_dims);
+ MapAsMatrixWithLastDimAsRows(im2col_data, im2col_shape);
const auto filter_matrix_map =
- MapAsMatrixWithLastDimAsCols(filter_data, filter_dims);
+ MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
auto output_matrix_map =
- MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ MapAsMatrixWithLastDimAsRows(output_data, output_shape);
Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
index 9aabee5000..a8428528c9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h
@@ -22,24 +22,36 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -52,25 +64,26 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
float total = 0.f;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
float input_value =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
float filter_value = filter_data[Offset(
- filter_dims, oc, filter_x, filter_y, 0)];
+ filter_shape, 0, filter_y, filter_x, oc)];
total += (input_value * filter_value);
}
}
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ bias_value = bias_data[oc];
}
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -81,6 +94,52 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
+ bias_dims, stride_width, stride_height, 1, 1, pad_width,
+ pad_height, depth_multiplier, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
@@ -97,6 +156,7 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
output_data, output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
index d57739279f..e8fc566502 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include "fixedpoint/fixedpoint.h"
-#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -26,26 +25,45 @@ limitations under the License.
namespace tflite {
namespace reference_ops {
-inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, int depth_multiplier,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- TFLITE_DCHECK(output_depth == input_depth * depth_multiplier);
+// TODO(b/80418076): Move to legacy ops file, along with invocations.
+static constexpr int kDepthwiseReverseShift = -1;
+
+inline void DepthwiseConv(
+ const DepthwiseParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
for (int b = 0; b < batches; ++b) {
for (int out_y = 0; out_y < output_height; ++out_y) {
@@ -58,30 +76,31 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
int32 acc = 0;
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + filter_x;
- const int in_y = in_y_origin + filter_y;
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
// If the location is outside the bounds of the input image,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
int32 input_val =
- input_data[Offset(input_dims, ic, in_x, in_y, b)];
- int32 filter_val = filter_data[Offset(filter_dims, oc,
- filter_x, filter_y, 0)];
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
+ int32 filter_val = filter_data[Offset(
+ filter_shape, 0, filter_y, filter_x, oc)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)];
+ acc += bias_data[oc];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, oc, out_x, out_y, b)] =
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
static_cast<uint8>(acc);
}
}
@@ -90,6 +109,63 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height,
+ int dilation_width_factor, int dilation_height_factor,
+ int pad_width, int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.depth_multiplier = depth_multiplier;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kDepthwiseReverseShift * output_shift;
+
+ DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, int depth_multiplier,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, stride_width,
+ stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
+ output_offset, output_multiplier, output_shift,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
@@ -113,6 +189,7 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
output_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy, for compatibility with old checked-in code.
template <FusedActivationFunctionType Ac>
void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
new file mode 100644
index 0000000000..23325e8c4c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h
@@ -0,0 +1,460 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+namespace reference_ops {
+
+const int kReverseShift = -1;
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dims_count = output_shape.DimensionsCount();
+ const int weights_dims_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
+ output_shape, output_dims_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ float total = 0.f;
+ for (int d = 0; d < accum_depth; ++d) {
+ total += input_data[b * accum_depth + d] *
+ weights_data[out_c * accum_depth + d];
+ }
+ float bias_value = 0.0f;
+ if (bias_data) {
+ bias_value = bias_data[out_c];
+ }
+ output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
+ total + bias_value, output_activation_min, output_activation_max);
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims,
+ float output_activation_min,
+ float output_activation_max, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), weights_data,
+ DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const float* input_data, const Dims<4>& input_dims,
+ const float* weights_data, const Dims<4>& weights_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+ FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
+ bias_dims, output_activation_min, output_activation_max,
+ output_data, output_dims);
+}
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ int32 acc = 0;
+ for (int d = 0; d < accum_depth; ++d) {
+ int32 input_val = input_data[b * accum_depth + d];
+ int32 filter_val = filter_data[out_c * accum_depth + d];
+ acc += (filter_val + filter_offset) * (input_val + input_offset);
+ }
+ if (bias_data) {
+ acc += bias_data[out_c];
+ }
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, void* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void FullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ TFLITE_DCHECK_EQ(output_offset, 0);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int filter_dim_count = filter_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(filter_shape, filter_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = filter_shape.Dims(filter_dim_count - 1);
+ for (int b = 0; b < batches; ++b) {
+ for (int out_c = 0; out_c < output_depth; ++out_c) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum = bias_data[out_c];
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; ++d) {
+ int16 input_val = input_data[b * accum_depth + d] + input_offset;
+ int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
+ accum += filter_val * input_val;
+ }
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ accum =
+ MultiplyByQuantizedMultiplier(accum, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ accum = std::max(accum, output_activation_min - output_offset);
+ accum = std::min(accum, output_activation_max - output_offset);
+ accum += output_offset;
+ output_data[out_c + output_depth * b] = accum;
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, int16* output_data,
+ const Dims<4>& output_dims, void* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ FullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(output_dims), output_data,
+ gemm_context);
+}
+
+inline void ShuffledFullyConnected(
+ const FullyConnectedParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& weights_shape,
+ const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int16* output_data, uint8* shuffled_input_workspace_data,
+ void* gemm_context) {
+ (void)gemm_context; // only used in optimized code.
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+
+ TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+ TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1);
+ // TODO(benoitjacob): This really should be:
+ // const int batches = ArraySize(output_dims, 1);
+ // but the current --variable_batch hack consists in overwriting the 3rd
+ // dimension with the runtime batch size, as we don't keep track for each
+ // array of which dimension is the batch dimension in it.
+ const int output_dim_count = output_shape.DimensionsCount();
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1);
+ const int output_depth = MatchingDim(weights_shape, weights_dim_count - 2,
+ output_shape, output_dim_count - 1);
+ const int accum_depth = weights_shape.Dims(weights_dim_count - 1);
+ TFLITE_DCHECK((accum_depth % 16) == 0);
+ TFLITE_DCHECK((output_depth % 4) == 0);
+
+ // Shuffling and xoring of input activations into the workspace buffer
+ uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
+ if (batches == 1) {
+ for (int i = 0; i < accum_depth; i++) {
+ shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
+ }
+ } else if (batches == 4) {
+ for (int c = 0; c < accum_depth; c += 16) {
+ for (int b = 0; b < 4; b++) {
+ const uint8* src_data_ptr = input_data + b * accum_depth + c;
+ for (int j = 0; j < 16; j++) {
+ uint8 src_val = *src_data_ptr++;
+ // Flip the sign bit, so that the kernel will only need to
+ // reinterpret these uint8 values as int8, getting for free the
+ // subtraction of the zero_point value 128.
+ uint8 dst_val = src_val ^ 0x80;
+ *shuffled_input_workspace_ptr++ = dst_val;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+
+ // Actual computation
+ if (batches == 1) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum[4] = {0};
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_data[d + j];
+ int8 weights_val = *shuffled_weights_ptr++;
+ accum[i] += weights_val * input_val;
+ }
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ // Add bias value
+ int32 acc = accum[i] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc =
+ MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[c + i] = acc;
+ }
+ }
+ } else if (batches == 4) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr = shuffled_input_data;
+ // Accumulation loop.
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ accum[i][b] = 0;
+ }
+ }
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_ptr[16 * b + j];
+ int8 weights_val = shuffled_weights_ptr[16 * i + j];
+ accum[i][b] += weights_val * input_val;
+ }
+ }
+ }
+ shuffled_input_ptr += 64;
+ shuffled_weights_ptr += 64;
+ }
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ // Add bias value
+ int32 acc = accum[i][b] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The
+ // quantized multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[b * output_depth + c + i] = acc;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void ShuffledFullyConnected(
+ const uint8* input_data, const Dims<4>& input_dims,
+ const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
+ const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims,
+ uint8* shuffled_input_workspace_data, void* gemm_context) {
+ tflite::FullyConnectedParams op_params;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(weights_dims), shuffled_weights_data,
+ DimsToShape(bias_dims), bias_data,
+ DimsToShape(output_dims), output_data,
+ shuffled_input_workspace_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims, void* gemm_context) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
+ filter_offset, bias_data, bias_dims, output_offset,
+ output_multiplier, output_shift, output_activation_min,
+ output_activation_max, output_data, output_dims, gemm_context);
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_FULLY_CONNECTED_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index 77e60adc18..70d25c4bd9 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -55,7 +55,7 @@ void PortableSymmetricQuantizeFloats(const float* values, const int size,
return;
}
*scaling_factor = range / kScale;
- const float scaling_factor_inv = 1.0f / *scaling_factor;
+ const float scaling_factor_inv = kScale / range;
for (int i = 0; i < size; ++i) {
const int32_t quantized_value =
static_cast<int32_t>(TfLiteRound(values[i] * scaling_factor_inv));
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 0abacf85e1..7a5535489a 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -28,6 +28,8 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
@@ -98,18 +100,6 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
namespace reference_ops {
-// TODO(b/80247582) Remove this constant.
-// This will be phased out as the shifts are revised with more thought. Use of a
-// constant enables us to track progress on this work.
-//
-// Used mainly to convert from old-style shifts (right) to new-style (left).
-static constexpr int kReverseShift = -1;
-
-inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
- return RuntimeShape(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
shape->BuildFrom(
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
@@ -168,28 +158,38 @@ SaturatingRoundingMultiplyByPOTParam(
SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent));
}
-inline void Conv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims,
- float* im2col_data, const Dims<4>& im2col_dims) {
- (void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& bias_shape,
+ const float* bias_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape,
+ float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
if (bias_data) {
- TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0));
- }
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -207,11 +207,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
total += (input_value * filter_value);
}
}
@@ -219,9 +219,9 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
float bias_value = 0.0f;
if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ bias_value = bias_data[out_channel];
}
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
ActivationFunctionWithMinMax(total + bias_value,
output_activation_min,
output_activation_max);
@@ -231,6 +231,35 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ float* im2col_data, const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -248,6 +277,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -264,6 +294,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
@@ -277,31 +308,45 @@ void Conv(const float* input_data, const Dims<4>& input_dims,
output_dims, im2col_data, im2col_dims);
}
-inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int stride_width, int stride_height, int dilation_width_factor,
- int dilation_height_factor, int pad_width, int pad_height,
- int32 output_offset, int32 output_multiplier, int output_shift,
- int32 output_activation_min, int32 output_activation_max,
- uint8* output_data, const Dims<4>& output_dims,
- uint8* im2col_data, const Dims<4>& im2col_dims,
- gemmlowp::GemmContext* gemm_context) {
+inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& filter_shape,
+ const uint8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ uint8* output_data, const RuntimeShape& im2col_shape,
+ uint8* im2col_data, gemmlowp::GemmContext* gemm_context) {
(void)im2col_data; // only used in optimized code.
- (void)im2col_dims; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
(void)gemm_context; // only used in optimized code.
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth =
- MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ if (bias_data) {
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+ }
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
for (int batch = 0; batch < batches; ++batch) {
for (int out_y = 0; out_y < output_height; ++out_y) {
for (int out_x = 0; out_x < output_width; ++out_x) {
@@ -319,11 +364,11 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
// use zero as a default value.
if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
(in_y < input_height)) {
- int32 input_val = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ int32 input_val = input_data[Offset(input_shape, batch, in_y,
+ in_x, in_channel)];
int32 filter_val =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
acc +=
(filter_val + filter_offset) * (input_val + input_offset);
}
@@ -331,14 +376,14 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)];
+ acc += bias_data[out_channel];
}
acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
+ output_shift);
acc += output_offset;
acc = std::max(acc, output_activation_min);
acc = std::min(acc, output_activation_max);
- output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] =
+ output_data[Offset(output_shape, batch, out_y, out_x, out_channel)] =
static_cast<uint8>(acc);
}
}
@@ -346,6 +391,44 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
+ int32 input_offset, const uint8* filter_data,
+ const Dims<4>& filter_dims, int32 filter_offset,
+ const int32* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int dilation_width_factor,
+ int dilation_height_factor, int pad_width, int pad_height,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims,
+ uint8* im2col_data, const Dims<4>& im2col_dims,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+ op_params.dilation_width_factor = dilation_width_factor;
+ op_params.dilation_height_factor = dilation_height_factor;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
+ filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data, gemm_context);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
int32 input_offset, const uint8* filter_data,
const Dims<4>& filter_dims, int32 filter_offset,
@@ -364,6 +447,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -393,6 +477,7 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void Conv(const uint8* input_data, const Dims<4>& input_dims,
@@ -419,9 +504,9 @@ inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -472,9 +557,9 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
T* output_data) {
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int input_depth = input_shape.Dims(3);
@@ -516,320 +601,6 @@ inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
}
}
-inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims,
- float output_activation_min,
- float output_activation_max, float* output_data,
- const Dims<4>& output_dims) {
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- float total = 0.f;
- for (int d = 0; d < accum_depth; ++d) {
- total += input_data[b * accum_depth + d] *
- weights_data[out_c * accum_depth + d];
- }
- float bias_value = 0.0f;
- if (bias_data) {
- bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax(
- total + bias_value, output_activation_min, output_activation_max);
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const float* input_data, const Dims<4>& input_dims,
- const float* weights_data, const Dims<4>& weights_dims,
- const float* bias_data, const Dims<4>& bias_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
- FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
- bias_dims, output_activation_min, output_activation_max,
- output_data, output_dims);
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- int32 acc = 0;
- for (int d = 0; d < accum_depth; ++d) {
- int32 input_val = input_data[b * accum_depth + d];
- int32 filter_val = filter_data[out_c * accum_depth + d];
- acc += (filter_val + filter_offset) * (input_val + input_offset);
- }
- if (bias_data) {
- acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)];
- }
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- kReverseShift * output_shift);
- acc += output_offset;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[out_c + output_depth * b] = static_cast<uint8>(acc);
- }
- }
-}
-
-inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, int16* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- TFLITE_DCHECK_EQ(output_offset, 0);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(filter_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims));
- for (int b = 0; b < batches; ++b) {
- for (int out_c = 0; out_c < output_depth; ++out_c) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum = bias_data[out_c];
- // Accumulation loop.
- for (int d = 0; d < accum_depth; ++d) {
- int16 input_val = input_data[b * accum_depth + d] + input_offset;
- int16 filter_val = filter_data[out_c * accum_depth + d] + filter_offset;
- accum += filter_val * input_val;
- }
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- accum = MultiplyByQuantizedMultiplier(accum, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- accum = std::max(accum, output_activation_min - output_offset);
- accum = std::min(accum, output_activation_max - output_offset);
- accum += output_offset;
- output_data[out_c + output_depth * b] = accum;
- }
- }
-}
-
-inline void ShuffledFullyConnected(
- const uint8* input_data, const Dims<4>& input_dims,
- const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
- const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
- int output_shift, int32 output_activation_min, int32 output_activation_max,
- int16* output_data, const Dims<4>& output_dims,
- uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
- (void)gemm_context; // only used in optimized code.
-
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- // TODO(benoitjacob): This really should be:
- // const int batches = ArraySize(output_dims, 1);
- // but the current --variable_batch hack consists in overwriting the 3rd
- // dimension with the runtime batch size, as we don't keep track for each
- // array of which dimension is the batch dimension in it.
- const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) *
- ArraySize(output_dims, 3);
- const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0);
- const int accum_depth = ArraySize(weights_dims, 0);
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- TFLITE_DCHECK((accum_depth % 16) == 0);
- TFLITE_DCHECK((output_depth % 4) == 0);
-
- // Shuffling and xoring of input activations into the workspace buffer
- uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
- if (batches == 1) {
- for (int i = 0; i < accum_depth; i++) {
- shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
- }
- } else if (batches == 4) {
- for (int c = 0; c < accum_depth; c += 16) {
- for (int b = 0; b < 4; b++) {
- const uint8* src_data_ptr = input_data + b * accum_depth + c;
- for (int j = 0; j < 16; j++) {
- uint8 src_val = *src_data_ptr++;
- // Flip the sign bit, so that the kernel will only need to
- // reinterpret these uint8 values as int8, getting for free the
- // subtraction of the zero_point value 128.
- uint8 dst_val = src_val ^ 0x80;
- *shuffled_input_workspace_ptr++ = dst_val;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-
- // Actual computation
- if (batches == 1) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4] = {0};
- // Accumulation loop.
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_data[d + j];
- int8 weights_val = *shuffled_weights_ptr++;
- accum[i] += weights_val * input_val;
- }
- }
- }
- for (int i = 0; i < 4; i++) {
- // Add bias value
- int acc = accum[i] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[c + i] = acc;
- }
- }
- } else if (batches == 4) {
- int16* output_ptr = output_data;
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- // Likewise, we preshuffled and pre-xored the input data above.
- const int8* shuffled_input_data =
- reinterpret_cast<const int8*>(shuffled_input_workspace_data);
- for (int c = 0; c < output_depth; c += 4) {
- const int8* shuffled_input_ptr = shuffled_input_data;
- // Accumulation loop.
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4][4];
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- accum[i][b] = 0;
- }
- }
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- for (int j = 0; j < 16; j++) {
- int8 input_val = shuffled_input_ptr[16 * b + j];
- int8 weights_val = shuffled_weights_ptr[16 * i + j];
- accum[i][b] += weights_val * input_val;
- }
- }
- }
- shuffled_input_ptr += 64;
- shuffled_weights_ptr += 64;
- }
- for (int i = 0; i < 4; i++) {
- for (int b = 0; b < 4; b++) {
- // Add bias value
- int acc = accum[i][b] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The
- // quantized multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_ptr[b * output_depth + c + i] = acc;
- }
- }
- }
- } else {
- TFLITE_DCHECK(false);
- return;
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
- int32 input_offset, const uint8* filter_data,
- const Dims<4>& filter_dims, int32 filter_offset,
- const int32* bias_data, const Dims<4>& bias_dims,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
- static_assert(Ac == FusedActivationFunctionType::kNone ||
- Ac == FusedActivationFunctionType::kRelu ||
- Ac == FusedActivationFunctionType::kRelu6 ||
- Ac == FusedActivationFunctionType::kRelu1,
- "");
- if (Ac == FusedActivationFunctionType::kNone) {
- TFLITE_DCHECK_EQ(output_activation_min, 0);
- TFLITE_DCHECK_EQ(output_activation_max, 255);
- }
- FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
- filter_offset, bias_data, bias_dims, output_offset,
- output_multiplier, output_shift, output_activation_min,
- output_activation_max, output_data, output_dims, gemm_context);
-}
-
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -950,6 +721,7 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_inv_sqrt <<= -*output_shift;
*output_shift = 0;
}
+ // Convert right shift (right is positive) to left shift.
*output_shift *= kReverseShift;
}
@@ -1117,7 +889,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1158,7 +930,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1200,7 +972,7 @@ inline void BroadcastAdd4DSlow(const ArithmeticParams& params,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1350,7 +1122,7 @@ void BroadcastMul4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -1483,7 +1255,7 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
// The input shapes are extended as part of NdArrayDesc initialization.
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
@@ -1579,7 +1351,7 @@ void BroadcastDiv4DSlow(const ArithmeticParams& params,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -1708,12 +1480,12 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1749,12 +1521,12 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const uint8* input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1813,12 +1585,12 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1853,12 +1625,12 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -1897,7 +1669,7 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
- RuntimeShape extended_output_shape =
+ const RuntimeShape extended_output_shape =
RuntimeShape::ExtendedShape(4, output_shape);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
@@ -2204,56 +1976,90 @@ void DepthConcatenation(const Scalar* const* input_data,
output_data, output_dims);
}
-inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
- const float* prev_activ_data,
- const Dims<4>& prev_activ_dims, const float* weights_data,
- const Dims<4>& weights_dims, const float* bias_data,
- const Dims<4>& bias_dims, const float* prev_state_data,
- const Dims<4>& prev_state_dims, float* output_state_data,
- const Dims<4>& output_state_dims, float* output_activ_data,
- const Dims<4>& output_activ_dims, float* concat_temp_data,
- const Dims<4>& concat_temp_dims, float* activ_temp_data,
- const Dims<4>& activ_temp_dims) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const float* input_data, const RuntimeShape& unextended_prev_activ_shape,
+ const float* prev_activ_data, const RuntimeShape& weights_shape,
+ const float* weights_data, const RuntimeShape& unextended_bias_shape,
+ const float* bias_data, const RuntimeShape& unextended_prev_state_shape,
+ const float* prev_state_data,
+ const RuntimeShape& unextended_output_state_shape, float* output_state_data,
+ const RuntimeShape& unextended_output_activ_shape, float* output_activ_data,
+ const RuntimeShape& unextended_concat_temp_shape, float* concat_temp_data,
+ const RuntimeShape& unextended_activ_temp_shape, float* activ_temp_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
+
+ const int weights_dim_count = weights_shape.DimensionsCount();
const int batches =
- MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3,
- output_state_dims, 3, output_activ_dims, 3);
+ MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
+ output_state_shape, 0, output_activ_shape, 0);
const int height =
- MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2,
- output_state_dims, 2, output_activ_dims, 2);
+ MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
+ output_state_shape, 1, output_activ_shape, 1);
const int width =
- MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1,
- output_state_dims, 1, output_activ_dims, 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
+ output_state_shape, 2, output_activ_shape, 2);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
// Concatenate prev_activ and input data together
std::vector<float const*> concat_input_arrays_data;
- std::vector<Dims<4> const*> concat_input_arrays_dims;
+ std::vector<RuntimeShape const*> concat_input_arrays_shapes;
concat_input_arrays_data.push_back(input_data);
concat_input_arrays_data.push_back(prev_activ_data);
- concat_input_arrays_dims.push_back(&input_dims);
- concat_input_arrays_dims.push_back(&prev_activ_dims);
- Concatenation<FusedActivationFunctionType::kNone, float>(
- 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]),
- concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims);
+ concat_input_arrays_shapes.push_back(&input_shape);
+ concat_input_arrays_shapes.push_back(&prev_activ_shape);
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = concat_input_arrays_data.size();
+ Concatenation(concat_params, &(concat_input_arrays_shapes[0]),
+ &(concat_input_arrays_data[0]), concat_temp_shape,
+ concat_temp_data);
// Fully connected
- FullyConnected<FusedActivationFunctionType::kNone>(
- concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data,
- bias_dims, activ_temp_data, activ_temp_dims);
+ tflite::FullyConnectedParams fc_params;
+ fc_params.float_activation_min = std::numeric_limits<float>::lowest();
+ fc_params.float_activation_max = std::numeric_limits<float>::max();
+ FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape,
+ weights_data, bias_shape, bias_data, activ_temp_shape,
+ activ_temp_data);
// Memory state update (the LSTM "guts")
for (int b = 0; b < batches; ++b) {
@@ -2262,24 +2068,24 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
for (int c = 0; c < output_depth; ++c) {
const float input_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 0 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 0 * output_depth + c)]));
const float new_input = std::tanh(activ_temp_data[Offset(
- activ_temp_dims, 1 * output_depth + c, w, h, b)]);
+ activ_temp_shape, b, h, w, 1 * output_depth + c)]);
const float forget_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 2 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 2 * output_depth + c)]));
const float output_gate =
1.f /
- (1.f + std::exp(-activ_temp_data[Offset(
- activ_temp_dims, 3 * output_depth + c, w, h, b)]));
+ (1.f + std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w,
+ 3 * output_depth + c)]));
const float new_state =
input_gate * new_input +
forget_gate *
- prev_state_data[Offset(prev_state_dims, c, w, h, b)];
- output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state;
- output_activ_data[Offset(output_activ_dims, c, w, h, b)] =
+ prev_state_data[Offset(prev_state_shape, b, h, w, c)];
+ output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
+ output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
output_gate * std::tanh(new_state);
}
}
@@ -2287,6 +2093,31 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
+ const float* prev_activ_data,
+ const Dims<4>& prev_activ_dims, const float* weights_data,
+ const Dims<4>& weights_dims, const float* bias_data,
+ const Dims<4>& bias_dims, const float* prev_state_data,
+ const Dims<4>& prev_state_dims, float* output_state_data,
+ const Dims<4>& output_state_dims, float* output_activ_data,
+ const Dims<4>& output_activ_dims, float* concat_temp_data,
+ const Dims<4>& concat_temp_dims, float* activ_temp_data,
+ const Dims<4>& activ_temp_dims) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
+
+ LstmCell(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(prev_activ_dims), prev_activ_data,
+ DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
+ bias_data, DimsToShape(prev_state_dims), prev_state_data,
+ DimsToShape(output_state_dims), output_state_data,
+ DimsToShape(output_activ_dims), output_activ_data,
+ DimsToShape(concat_temp_dims), concat_temp_data,
+ DimsToShape(activ_temp_dims), activ_temp_data);
+}
+
// Quantized LSTM cell implementation.
// The quantization of the input, output arrays is as follows:
// - The input activations are quantized as uint8 on the interval
@@ -2372,52 +2203,90 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
// aiming for 16-bit fixed-point quantization of these internal nodes here.
//
template <int StateIntegerBits>
-void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
- const uint8* prev_activ_data_uint8,
- const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
- const Dims<4>& weights_dims, const int32* bias_data_int32,
- const Dims<4>& bias_dims, const int16* prev_state_data_int16,
- const Dims<4>& prev_state_dims, int16* output_state_data_int16,
- const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
- const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
- const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
- const Dims<4>& activ_temp_dims, int32 weights_zero_point,
- int32 accum_multiplier, int accum_shift,
- gemmlowp::GemmContext* gemm_context) {
+inline void LstmCell(
+ const LstmCellParams& params, const RuntimeShape& unextended_input_shape,
+ const uint8* input_data_uint8,
+ const RuntimeShape& unextended_prev_activ_shape,
+ const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape,
+ const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape,
+ const int32* bias_data_int32,
+ const RuntimeShape& unextended_prev_state_shape,
+ const int16* prev_state_data_int16,
+ const RuntimeShape& unextended_output_state_shape,
+ int16* output_state_data_int16,
+ const RuntimeShape& unextended_output_activ_shape,
+ uint8* output_activ_data_uint8,
+ const RuntimeShape& unextended_concat_temp_shape,
+ uint8* concat_temp_data_uint8,
+ const RuntimeShape& unextended_activ_temp_shape,
+ int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+ int32 weights_zero_point = params.weights_zero_point;
+ int32 accum_multiplier = params.accum_multiplier;
+ int accum_shift = params.accum_shift;
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape prev_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
+ const RuntimeShape bias_shape =
+ RuntimeShape::ExtendedShape(4, unextended_bias_shape);
+ const RuntimeShape prev_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
+ const RuntimeShape output_state_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
+ const RuntimeShape output_activ_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
+ const RuntimeShape concat_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
+ const RuntimeShape activ_temp_shape =
+ RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
+ TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
// Gather dimensions information, and perform consistency checks.
- const int outer_size =
- MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prev_state_dims,
- output_state_dims, output_activ_dims);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
- const int input_depth = ArraySize(input_dims, 0);
- const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
+ const int weights_dim_count = weights_shape.DimensionsCount();
+ const int outer_size = MatchingFlatSizeSkipDim(
+ input_shape, 3, prev_activ_shape, prev_state_shape, output_state_shape,
+ output_activ_shape);
+ const int input_depth = input_shape.Dims(3);
+ const int prev_activ_depth = prev_activ_shape.Dims(3);
const int total_input_depth = prev_activ_depth + input_depth;
- TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
- TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3),
- 1);
+ TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1),
+ total_input_depth);
const int intern_activ_depth =
- MatchingArraySize(weights_dims, 1, bias_dims, 0);
- TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
+ MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
+ TFLITE_DCHECK_EQ(weights_shape.FlatSize(),
+ intern_activ_depth * total_input_depth);
+ TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
+ TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
const int output_depth =
- MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0,
- output_state_dims, 0, output_activ_dims, 0);
- TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
- const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
+ MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
+ 3, output_activ_shape, 3);
+ TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
+ const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
const int fc_output_depth =
- MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
- const int fc_accum_depth = ArraySize(weights_dims, 0);
- TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
+ MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
+ const int fc_accum_depth = total_input_depth;
+ TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
// Depth-concatenate prev_activ and input data together.
uint8 const* concat_input_arrays_data[2] = {input_data_uint8,
prev_activ_data_uint8};
- Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
- Concatenation<FusedActivationFunctionType::kNone, uint8>(
- 0, concat_input_arrays_data, concat_input_arrays_dims, 2,
- concat_temp_data_uint8, concat_temp_dims);
+ const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape,
+ &prev_activ_shape};
+ tflite::ConcatenationParams concat_params;
+ concat_params.axis = 3;
+ concat_params.inputs_count = 2;
+ Concatenation(concat_params, concat_input_arrays_shapes,
+ concat_input_arrays_data, concat_temp_shape,
+ concat_temp_data_uint8);
// Implementation of the fully connected node inside the LSTM cell.
// The operands are 8-bit integers, the accumulators are internally 32bit
@@ -2523,6 +2392,37 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <int StateIntegerBits>
+void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
+ const uint8* prev_activ_data_uint8,
+ const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
+ const Dims<4>& weights_dims, const int32* bias_data_int32,
+ const Dims<4>& bias_dims, const int16* prev_state_data_int16,
+ const Dims<4>& prev_state_dims, int16* output_state_data_int16,
+ const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
+ const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
+ const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
+ const Dims<4>& activ_temp_dims, int32 weights_zero_point,
+ int32 accum_multiplier, int accum_shift,
+ gemmlowp::GemmContext* gemm_context) {
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights_zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
+
+ LstmCell<StateIntegerBits>(
+ op_params, DimsToShape(input_dims), input_data_uint8,
+ DimsToShape(prev_activ_dims), prev_activ_data_uint8,
+ DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
+ bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
+ DimsToShape(output_state_dims), output_state_data_int16,
+ DimsToShape(output_activ_dims), output_activ_data_uint8,
+ DimsToShape(concat_temp_dims), concat_temp_data_uint8,
+ DimsToShape(activ_temp_dims), activ_temp_data_int16, gemm_context);
+}
+
template <typename Scalar>
void Split(const SplitParams& params, const RuntimeShape& input_shape,
const Scalar* input_data, const RuntimeShape* const* output_shapes,
@@ -2902,121 +2802,9 @@ inline void LocalResponseNormalization(
}
}
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- for (int i = 0; i < outer_size; ++i) {
- // Find max element value which we'll use to ensure numerical stability
- // taking advantage of the following equality:
- // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
- float max = std::numeric_limits<float>::lowest();
- for (int c = 0; c < depth; ++c) {
- max = std::max(max, input_data[i * depth + c]);
- }
-
- // Compute sum.
- float sum = 0.f;
- for (int c = 0; c < depth; ++c) {
- sum += std::exp((input_data[i * depth + c] - max) * beta);
- }
-
- // Compute result.
- for (int c = 0; c < depth; ++c) {
- output_data[i * depth + c] =
- std::exp((input_data[i * depth + c] - max) * beta) / sum;
- }
- }
-}
-
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
- const RuntimeShape& output_shape) {
- // The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
- static const int kScaledDiffIntegerBits = 5;
- static const int kAccumulationIntegerBits = 12;
- using FixedPointScaledDiff =
- gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
- using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
- using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
-
- const int trailing_dim = input_shape.DimensionsCount() - 1;
- const int outer_size =
- MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
- const int depth =
- MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
-
- for (int i = 0; i < outer_size; ++i) {
- uint8 max_in_row = 0;
- for (int c = 0; c < depth; ++c) {
- max_in_row = std::max(max_in_row, input_data[i * depth + c]);
- }
-
- FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
- sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
- exp_on_negative_values(scaled_diff_f8));
- }
- }
-
- int32 fixed_sum_of_exps = sum_of_exps.raw();
- int headroom_plus_one =
- CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
- // This is the number of bits to the left of the binary point above 1.0.
- // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
- // no later adjustment will be needed.
- int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
- int32 shifted_sum_minus_one = static_cast<int32>(
- (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
- (static_cast<uint32>(1) << 31));
-
- FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
- FixedPoint0::FromRaw(shifted_sum_minus_one));
-
- for (int c = 0; c < depth; ++c) {
- int32 input_diff =
- static_cast<int32>(input_data[i * depth + c]) - max_in_row;
- if (input_diff >= diff_min) {
- const int32 input_diff_rescaled =
- MultiplyByQuantizedMultiplierGreaterThanOne(
- input_diff, input_beta_multiplier, input_beta_left_shift);
- const FixedPointScaledDiff scaled_diff_f8 =
- FixedPointScaledDiff::FromRaw(input_diff_rescaled);
-
- FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
- int32 unsat_output = gemmlowp::RoundingDivideByPOT(
- (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
-
- output_data[i * depth + c] = static_cast<uint8>(
- std::max(std::min(unsat_output, static_cast<int32>(255)), 0));
-
- } else {
- output_data[i * depth + c] = 0;
- }
- }
- }
-}
-
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -3046,6 +2834,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
@@ -3161,16 +2958,19 @@ log_x_for_x_greater_than_or_equal_to_1(
input_val);
}
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
+ // We need to leave extra space since values that we skip might be as large
+ // as -32 before multiplying by input_beta_multiplier, and therefore as
+ // large as -16 afterwards. Note that exp(-8) is definitely not
+ // insignificant to accumulation, but exp(-16) definitely is.
static constexpr int kScaledDiffIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
@@ -3222,7 +3022,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
std::max(diff_min - 1, // Note use of > below instead of >= above.
MultiplyByQuantizedMultiplierSmallerThanOneExp(
rescaled_diff_min, reverse_scaling_divisor,
- kReverseShift * reverse_scaling_right_shift));
+ -reverse_scaling_right_shift));
for (int c = 0; c < depth; ++c) {
int32 input_diff =
@@ -3247,6 +3047,22 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3258,10 +3074,22 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3296,7 +3124,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3314,6 +3157,15 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3325,10 +3177,22 @@ inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int32 output_zero_point = 128;
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3365,9 +3229,24 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -3398,6 +3277,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Dequantize(const tflite::DequantizationParams& op_params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, float* output_data) {
@@ -3543,11 +3432,11 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_size_shape =
+ const RuntimeShape output_size_shape =
RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
@@ -3606,9 +3495,9 @@ inline void SpaceToBatchND(
const RuntimeShape& unextended_output_shape, T* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int depth = input1_shape.Dims(3);
@@ -3663,9 +3552,9 @@ inline void BatchToSpaceND(
const RuntimeShape& unextended_output_shape, T* output_data) {
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input1_shape =
+ const RuntimeShape input1_shape =
RuntimeShape::ExtendedShape(4, unextended_input1_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_width = output_shape.Dims(2);
@@ -3719,8 +3608,10 @@ inline void PadImpl(const tflite::PadParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const P* pad_value_ptr, const RuntimeShape& output_shape,
T* output_data) {
- RuntimeShape ext_input_shape = RuntimeShape::ExtendedShape(4, input_shape);
- RuntimeShape ext_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
+ const RuntimeShape ext_input_shape =
+ RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_output_shape =
+ RuntimeShape::ExtendedShape(4, output_shape);
TFLITE_DCHECK_LE(op_params.left_padding_count, 4);
TFLITE_DCHECK_LE(op_params.right_padding_count, 4);
@@ -3817,9 +3708,9 @@ inline void StridedSlice(const tflite::StridedSliceParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
// Reverse and pad to 4 dimensions because that is what the runtime code
@@ -3915,7 +3806,7 @@ template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
const RuntimeShape& output_shape, T* output_data) {
- RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
+ const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(4, input_shape);
// TODO(dkalenichenko): This op only supports 4D tensors or smaller.
TFLITE_DCHECK_LE(op_params.begin_count, 4);
TFLITE_DCHECK_LE(op_params.size_count, 4);
@@ -4141,9 +4032,9 @@ inline void Mean(const tflite::MeanParams& op_params,
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape input_shape =
+ const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int output_batch = output_shape.Dims(0);
@@ -4196,12 +4087,15 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis for quantized values.
template <typename T, typename U>
-inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
- const int* input_dims, const int input_num_dims,
- T* output_data, int32 output_zero_point, float output_scale,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis, U* temp_sum) {
+inline bool QuantizedMeanOrSum(const T* input_data, int32 input_zero_point,
+ float input_scale, const int* input_dims,
+ const int input_num_dims, T* output_data,
+ int32 output_zero_point, float output_scale,
+ const int* output_dims,
+ const int output_num_dims, const int* axis,
+ const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum,
+ bool compute_sum) {
// Reset output data.
size_t num_outputs = 1;
for (int idx = 0; idx < output_num_dims; ++idx) {
@@ -4243,14 +4137,24 @@ inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
if (num_elements_in_axis > 0) {
const float scale = input_scale / output_scale;
- const float bias = -input_zero_point * scale;
- for (size_t idx = 0; idx < num_outputs; ++idx) {
- float float_mean = static_cast<float>(temp_sum[idx]) /
- static_cast<float>(num_elements_in_axis);
-
- // Convert to float value.
- output_data[idx] =
- static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ if (compute_sum) {
+ // TODO(b/116341117): Eliminate float and do this completely in 8bit.
+ const float bias = -input_zero_point * scale * num_elements_in_axis + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ const U value = static_cast<U>(round(temp_sum[idx] * scale + bias)) +
+ output_zero_point;
+ output_data[idx] = static_cast<T>(value);
+ }
+ } else {
+ const float bias = -input_zero_point * scale + 0.5;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] = static_cast<T>(round(float_mean * scale + bias)) +
+ output_zero_point;
+ }
}
}
return true;
@@ -4268,6 +4172,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4280,6 +4194,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@@ -4290,7 +4214,7 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4355,50 +4279,105 @@ void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
std::greater<T1>());
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
-void Transpose(const T* input, const Dims<4>& input_dims, T* output,
- const Dims<4>& output_dims, const int* permuted_axes) {
+void Transpose(const TransposeParams& params,
+ const RuntimeShape& unextended_input_shape, const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ const int unextended_output_size = unextended_output_shape.DimensionsCount();
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_size, 4);
+ TFLITE_DCHECK_EQ(unextended_output_size, params.perm_count);
+ const RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+ const int input_ext_size = 4 - unextended_input_shape.DimensionsCount();
+ const int output_ext_size = 4 - unextended_output_size;
+
+ // The perm data is extended to match the output, each index incremented by
+ // the amount of front padding of the input shape.
+ int extended_perm[4];
+ for (int i = 0; i < output_ext_size; ++i) {
+ extended_perm[i] = i;
+ }
+ for (int i = 0; i < unextended_output_size; ++i) {
+ extended_perm[i + output_ext_size] = params.perm[i] + input_ext_size;
+ }
+
int out_sizes[4];
// Compute the inverse permutation array so we can do an output centered
// transpose. Also, check to make sure output_dims is matching input_dims.
for (int k = 0; k < 4; k++) {
- out_sizes[k] =
- MatchingArraySize(input_dims, permuted_axes[k], output_dims, k);
+ out_sizes[k] = MatchingDim(input_shape, extended_perm[k], output_shape, k);
}
// Naive transpose loop (iterate on output index and compute input index).
int o[4]; // loop index (on output).
int i[4];
for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) {
- i[permuted_axes[3]] = o[3];
+ i[extended_perm[3]] = o[3];
for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) {
- i[permuted_axes[2]] = o[2];
+ i[extended_perm[2]] = o[2];
for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) {
- i[permuted_axes[1]] = o[1];
+ i[extended_perm[1]] = o[1];
for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) {
- i[permuted_axes[0]] = o[0];
- output[Offset(output_dims, o)] = input[Offset(input_dims, i)];
+ i[extended_perm[0]] = o[0];
+ output_data[Offset(output_shape, o)] =
+ input_data[Offset(input_shape, i)];
}
}
}
}
}
-inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
- const float* filter_data, const Dims<4>& filter_dims,
- int stride_width, int stride_height, int pad_width,
- int pad_height, float* output_data,
- const Dims<4>& output_dims, float* /*im2col_data*/,
- const Dims<4>& /*im2col_dims*/) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0);
- const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int filter_height = ArraySize(filter_dims, 2);
- const int filter_width = ArraySize(filter_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+ const Dims<4>& output_dims, const int* permuted_axes) {
+ TransposeParams params;
+ params.perm_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ params.perm[i] = 3 - permuted_axes[3 - i];
+ }
+ Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
+ output);
+}
+
+inline void TransposeConv(
+ const ConvParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& filter_shape,
+ const float* filter_data, const RuntimeShape& output_shape,
+ float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ (void)im2col_data; // only used in optimized code.
+ (void)im2col_shape; // only used in optimized code.
+
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
+ const int output_depth = MatchingDim(filter_shape, 0, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
// Although transpose convolution simplifies to convolution with transposed
// weights for strides of 1, non-unitary striding complicates matters. To
@@ -4407,7 +4386,7 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// computing their influence on the output, rather than looping through the
// output elements in the typical "gather" access pattern of a conv. We
// therefore must initialize the output array to zero.
- const int num_elements = FlatSize(output_dims);
+ const int num_elements = output_shape.FlatSize();
for (int i = 0; i < num_elements; i++) {
output_data[i] = 0.0f;
}
@@ -4430,13 +4409,14 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
// We cannot accumulate out of bounds
if ((out_x >= 0) && (out_x < output_width) && (out_y >= 0) &&
(out_y < output_height)) {
- float input_value = input_data[Offset(input_dims, in_channel,
- in_x, in_y, batch)];
+ float input_value = input_data[Offset(
+ input_shape, batch, in_y, in_x, in_channel)];
float filter_value =
- filter_data[Offset(filter_dims, in_channel, filter_x,
- filter_y, out_channel)];
- output_data[Offset(output_dims, out_channel, out_x, out_y,
- batch)] += input_value * filter_value;
+ filter_data[Offset(filter_shape, out_channel, filter_y,
+ filter_x, in_channel)];
+ output_data[Offset(output_shape, batch, out_y, out_x,
+ out_channel)] +=
+ input_value * filter_value;
}
}
}
@@ -4447,6 +4427,27 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
+ const float* filter_data, const Dims<4>& filter_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* output_data,
+ const Dims<4>& output_dims, float* im2col_data,
+ const Dims<4>& im2col_dims) {
+ tflite::ConvParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = pad_width;
+ op_params.padding_values.height = pad_height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
+ TransposeConv(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(filter_dims), filter_data, DimsToShape(output_dims),
+ output_data, DimsToShape(im2col_dims), im2col_data);
+}
+
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
return lhs == rhs;
@@ -4557,9 +4558,11 @@ inline void Comparison(int left_shift, const T* input1_data,
op_params.left_shift = left_shift;
op_params.input1_offset = input1_offset;
op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input1_shift = kReverseShift * input1_shift;
op_params.input2_offset = input2_offset;
op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input2_shift = kReverseShift * input2_shift;
ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
@@ -4577,7 +4580,7 @@ inline void BroadcastComparison4DSlowImpl(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4636,7 +4639,7 @@ inline void BroadcastComparison4DSlowWithScaling(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4691,9 +4694,11 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
op_params.left_shift = left_shift;
op_params.input1_offset = input1_offset;
op_params.input1_multiplier = input1_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input1_shift = kReverseShift * input1_shift;
op_params.input2_offset = input2_offset;
op_params.input2_multiplier = input2_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
op_params.input2_shift = kReverseShift * input2_shift;
BroadcastComparison4DSlowWithScaling<T, F>(
@@ -4797,47 +4802,81 @@ TFLITE_COMPARISON_OP(LessEqual);
#undef TFLITE_COMPARISON_OP
template <typename D, typename T>
+void Select(const RuntimeShape& input_condition_shape,
+ const D* input_condition_data, const RuntimeShape& input_x_shape,
+ const T* input_x_data, const RuntimeShape& input_y_shape,
+ const T* input_y_data, const RuntimeShape& output_shape,
+ T* output_data) {
+ const int64_t flatsize = MatchingFlatSize(
+ input_condition_shape, input_x_shape, input_y_shape, output_shape);
+ for (int64_t i = 0; i < flatsize; ++i) {
+ output_data[i] =
+ input_condition_data[i] ? input_x_data[i] : input_y_data[i];
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename D, typename T>
inline void Select(const D* input_condition_data,
const Dims<4>& input_condition_dims, const T* input_x_data,
const Dims<4>& input_x_dims, const T* input_y_data,
const Dims<4>& input_y_dims, T* output_data,
const Dims<4>& output_dims) {
- const int64_t flatsize =
- MatchingFlatSize(input_x_dims, input_y_dims, output_dims);
- for (int64_t i = 0; i < flatsize; ++i) {
- output_data[i] =
- input_condition_data[i] ? input_x_data[i] : input_y_data[i];
- }
+ Select(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
+ input_y_data, DimsToShape(output_dims), output_data);
}
template <typename D, typename T>
-inline void RankOneSelect(const D* input_condition_data,
- const Dims<4>& input_condition_dims,
- const T* input_x_data, const Dims<4>& input_x_dims,
- const T* input_y_data, const Dims<4>& input_y_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int64_t rank = MatchingArraySize(input_condition_dims, 0, input_x_dims,
- 3, input_y_dims, 3, output_dims, 3);
+void RankOneSelect(const RuntimeShape& input_condition_shape,
+ const D* input_condition_data,
+ const RuntimeShape& input_x_shape, const T* input_x_data,
+ const RuntimeShape& input_y_shape, const T* input_y_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ const int64_t outer_size = input_condition_shape.FlatSize();
+ TFLITE_DCHECK_EQ(
+ MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0),
+ outer_size);
const int64_t inner_size =
- MatchingFlatSizeSkipDim(input_x_dims, 3, input_y_dims, output_dims);
+ MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
int64_t offset = 0;
- for (int64_t i = 0; i < rank; i++) {
+ for (int64_t i = 0; i < outer_size; i++) {
const T* input_data = input_condition_data[i] ? input_x_data : input_y_data;
memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
offset += inner_size;
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename D, typename T>
+inline void RankOneSelect(const D* input_condition_data,
+ const Dims<4>& input_condition_dims,
+ const T* input_x_data, const Dims<4>& input_x_dims,
+ const T* input_y_data, const Dims<4>& input_y_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
+ DimsToShape(input_x_dims), input_x_data,
+ DimsToShape(input_y_dims), input_y_data,
+ DimsToShape(output_dims), output_data);
+}
+
// For easy implementation, the indices is always a vector of size-4 vectors.
template <typename T, typename TI>
inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
- const T* values, T default_value, T* output_data,
- const Dims<4>& output_dims, bool value_is_scalar) {
+ const T* values, T default_value,
+ bool value_is_scalar,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
const int value_count = indices.size();
// First fill the output_data with default value.
- const int num_elements = FlatSize(output_dims);
+ const int num_elements = output_shape.FlatSize();
for (int i = 0; i < num_elements; ++i) {
output_data[i] = default_value;
}
@@ -4849,8 +4888,8 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
const std::vector<TI>& index = indices[i];
TFLITE_DCHECK_EQ(index.size(), 4);
const T value = *values; // just use the first value.
- output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
- value;
+ output_data[Offset(output_shape, index[0], index[1], index[2],
+ index[3])] = value;
}
return;
}
@@ -4860,11 +4899,21 @@ inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
const std::vector<TI>& index = indices[i];
TFLITE_DCHECK_EQ(index.size(), 4);
const T value = values[i];
- output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] =
+ output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] =
value;
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, typename TI>
+inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
+ const T* values, T default_value, T* output_data,
+ const Dims<4>& output_dims, bool value_is_scalar) {
+ SparseToDense(indices, values, default_value, value_is_scalar,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
@@ -4877,16 +4926,22 @@ inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
+inline void BroadcastPow4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
- const RuntimeShape& input2_shape,
+ const RuntimeShape& unextended_input2_shape,
const T* input2_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ const RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4923,7 +4978,7 @@ inline void BroadcastLogical4DSlow(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
@@ -4962,7 +5017,7 @@ inline void BroadcastBinaryFunction4DSlow(
TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
- RuntimeShape output_shape =
+ const RuntimeShape output_shape =
RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/softmax.h b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
new file mode 100644
index 0000000000..006174e8db
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/reference/softmax.h
@@ -0,0 +1,202 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
+
+#include "fixedpoint/fixedpoint.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/round.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace reference_ops {
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ // Find max element value which we'll use to ensure numerical stability
+ // taking advantage of the following equality:
+ // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C))
+ float max = std::numeric_limits<float>::lowest();
+ for (int c = 0; c < depth; ++c) {
+ max = std::max(max, input_data[i * depth + c]);
+ }
+
+ // Compute sum.
+ float sum = 0.f;
+ for (int c = 0; c < depth; ++c) {
+ sum += std::exp((input_data[i * depth + c] - max) * params.beta);
+ }
+
+ // Compute result.
+ for (int c = 0; c < depth; ++c) {
+ output_data[i * depth + c] =
+ std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
+ // The representation chosen for the input to the exp() function is Q5.26.
+ // We need to leave extra space since values that we skip might be as large as
+ // -32 before multiplying by input_beta_multiplier, and therefore as large as
+ // -16 afterwards. Note that exp(-8) is definitely not insignificant to
+ // accumulation, but exp(-16) definitely is.
+ static const int kScaledDiffIntegerBits = 5;
+ static const int kAccumulationIntegerBits = 12;
+ using FixedPointScaledDiff =
+ gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
+ using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
+ using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
+
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
+
+ for (int i = 0; i < outer_size; ++i) {
+ uint8 max_in_row = 0;
+ for (int c = 0; c < depth; ++c) {
+ max_in_row = std::max(max_in_row, input_data[i * depth + c]);
+ }
+
+ FixedPointAccum sum_of_exps = FixedPointAccum::Zero();
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+ sum_of_exps = sum_of_exps + gemmlowp::Rescale<kAccumulationIntegerBits>(
+ exp_on_negative_values(scaled_diff_f8));
+ }
+ }
+
+ int32 fixed_sum_of_exps = sum_of_exps.raw();
+ int headroom_plus_one =
+ CountLeadingZeros(static_cast<uint32>(fixed_sum_of_exps));
+ // This is the number of bits to the left of the binary point above 1.0.
+ // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and
+ // no later adjustment will be needed.
+ int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one;
+ int32 shifted_sum_minus_one = static_cast<int32>(
+ (static_cast<uint32>(fixed_sum_of_exps) << headroom_plus_one) -
+ (static_cast<uint32>(1) << 31));
+
+ FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+ FixedPoint0::FromRaw(shifted_sum_minus_one));
+
+ for (int c = 0; c < depth; ++c) {
+ int32 input_diff =
+ static_cast<int32>(input_data[i * depth + c]) - max_in_row;
+ if (input_diff >= diff_min) {
+ const int32 input_diff_rescaled =
+ MultiplyByQuantizedMultiplierGreaterThanOne(
+ input_diff, input_beta_multiplier, input_beta_left_shift);
+ const FixedPointScaledDiff scaled_diff_f8 =
+ FixedPointScaledDiff::FromRaw(input_diff_rescaled);
+
+ FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8);
+ int32 unsat_output = gemmlowp::RoundingDivideByPOT(
+ (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8);
+
+ output_data[i * depth + c] = static_cast<uint8>(
+ std::max(std::min(unsat_output, static_cast<int32>(255)),
+ static_cast<int32>(0)));
+
+ } else {
+ output_data[i * depth + c] = 0;
+ }
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+// Performs softmax along the input of size (input_size * batch_size).
+inline void Softmax(const float* in, const int input_size, const int batch_size,
+ const float beta, float* out) {
+ // TF_LITE_ASSERT(input_size > 0);
+
+ // For each batch
+ for (int b = 0; b < batch_size; b++) {
+ // Find the max coeff.
+ float max_coeff = in[0];
+ for (int i = 1; i < input_size; i++) {
+ if (in[i] > max_coeff) max_coeff = in[i];
+ }
+
+ // Compute the normalized sum of exps.
+ float exp_sum = 0.0;
+ for (int i = 0; i < input_size; i++) {
+ out[i] = std::exp((in[i] - max_coeff) * beta);
+ exp_sum += out[i];
+ }
+
+ // Divide by the sum of exps.
+ float reciprocal_sum_exp = 1.f / exp_sum;
+ for (int i = 0; i < input_size; i++) {
+ out[i] *= reciprocal_sum_exp;
+ }
+
+ // Advance in and out pointers for the next batch.
+ in += input_size;
+ out += input_size;
+ }
+}
+
+} // namespace reference_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_SOFTMAX_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc
index 9b1fd9b344..75d568ae3a 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc
@@ -19,41 +19,24 @@ limitations under the License.
namespace tflite {
-Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) {
- Dims<4> result;
- int cum_prod = 1;
-
- result.sizes[0] = depth;
- result.strides[0] = cum_prod;
- cum_prod *= result.sizes[0];
-
- result.sizes[1] = width;
- result.strides[1] = cum_prod;
- cum_prod *= result.sizes[1];
-
- result.sizes[2] = height;
- result.strides[2] = cum_prod;
- cum_prod *= result.sizes[2];
-
- result.sizes[3] = batch;
- result.strides[3] = cum_prod;
-
- return result;
-}
-
// this is a copied from an internal function in propagate_fixed_sizes.cc
-bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
- Dims<4>* output_dims, int* pad_width, int* pad_height) {
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int batch = ArraySize(input_dims, 3);
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height) {
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int batch = input_shape.Dims(0);
+
+ int dilated_filter_width = dilation_width_factor * (filter_width - 1) + 1;
+ int dilated_filter_height = dilation_height_factor * (filter_height - 1) + 1;
int output_height = 0;
int output_width = 0;
if (padding_type == PaddingType::kValid) {
- output_height = (input_height + stride - filter_height) / stride;
- output_width = (input_width + stride - filter_width) / stride;
+ output_height = (input_height + stride - dilated_filter_height) / stride;
+ output_width = (input_width + stride - dilated_filter_width) / stride;
} else if (padding_type == PaddingType::kSame) {
output_height = (input_height + stride - 1) / stride;
output_width = (input_width + stride - 1) / stride;
@@ -65,11 +48,14 @@ bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
return false;
}
- *pad_height =
- ((output_height - 1) * stride + filter_height - input_height) / 2;
- *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2;
- *output_dims =
- MakeDimsForInference(output_depth, output_width, output_height, batch);
+ *pad_height = std::max(
+ 0, ((output_height - 1) * stride + dilated_filter_height - input_height) /
+ 2);
+ *pad_width = std::max(
+ 0,
+ ((output_width - 1) * stride + dilated_filter_width - input_width) / 2);
+
+ output_shape->BuildFrom({batch, output_height, output_width, output_depth});
return true;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h
index 26078cef49..e4a383bedf 100644
--- a/tensorflow/contrib/lite/kernels/internal/test_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/test_util.h
@@ -26,13 +26,12 @@ limitations under the License.
namespace tflite {
-// Creates a Dims struct from a set of dimensions.
-Dims<4> MakeDimsForInference(int depth, int width, int height, int batch);
-
// Computes output and padding dimensions.
-bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width,
- int filter_height, int stride, PaddingType padding_type,
- Dims<4>* output_dims, int* pad_width, int* pad_height);
+bool ComputeConvSizes(const RuntimeShape& input_shape, int output_depth,
+ int filter_width, int filter_height, int stride,
+ int dilation_width_factor, int dilation_height_factor,
+ PaddingType padding_type, RuntimeShape* output_shape,
+ int* pad_width, int* pad_height);
// Returns a mt19937 random engine.
std::mt19937& RandomEngine();
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index c4c7cf3842..a3a5994c9c 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -15,9 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
+#include <algorithm>
#include <cstring>
-#include <iterator>
+#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -26,8 +27,8 @@ enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
enum class PaddingType : uint8 { kNone, kSame, kValid };
struct PaddingValues {
- int8 width;
- int8 height;
+ int16 width;
+ int16 height;
};
// This enumeration allows for non-default formats for the weights array
@@ -125,7 +126,11 @@ class RuntimeShape {
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
}
}
@@ -160,7 +165,11 @@ class RuntimeShape {
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
}
}
@@ -179,20 +188,31 @@ class RuntimeShape {
dims_[i] = val;
}
}
+
inline int32* DimsData() {
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
}
inline const int32* DimsData() const {
return size_ > kMaxSmallSize ? dims_pointer_ : dims_;
}
+ // The caller must ensure that the shape is no bigger than 4-D.
+ inline const int32* DimsDataUpTo4D() const { return dims_; }
inline void Resize(int dimensions_count) {
if (size_ > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
delete[] dims_pointer_;
+#endif // TF_LITE_STATIC_MEMORY
}
size_ = dimensions_count;
if (dimensions_count > kMaxSmallSize) {
+#ifdef TF_LITE_STATIC_MEMORY
+ TFLITE_CHECK(false && "No shape resizing supported on this platform");
+#else // TF_LITE_STATIC_MEMORY
dims_pointer_ = new int32[dimensions_count];
+#endif // TF_LITE_STATIC_MEMORY
}
}
@@ -283,6 +303,12 @@ inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
return result;
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
+ return RuntimeShape(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
if (num_dims == 0) {
@@ -340,11 +366,12 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
}
inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
- TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
- TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
- TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
- TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
- const int* dims_data = shape.DimsData();
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), 4);
+ const int* dims_data = shape.DimsDataUpTo4D();
+ TFLITE_DCHECK(i0 >= 0 && i0 < dims_data[0]);
+ TFLITE_DCHECK(i1 >= 0 && i1 < dims_data[1]);
+ TFLITE_DCHECK(i2 >= 0 && i2 < dims_data[2]);
+ TFLITE_DCHECK(i3 >= 0 && i3 < dims_data[3]);
return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
}
@@ -361,6 +388,10 @@ inline int Offset(const Dims<4>& dims, int* index) {
return Offset(dims, index[0], index[1], index[2], index[3]);
}
+inline int Offset(const RuntimeShape& shape, int* index) {
+ return Offset(shape, index[0], index[1], index[2], index[3]);
+}
+
// Get array size, DCHECKing that the dim index is in range.
//
// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
@@ -410,7 +441,7 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-// Deprecated. Prefer FlatSize.
+ABSL_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
@@ -734,10 +765,10 @@ struct ConvParams {
PaddingType padding_type;
PaddingValues padding_values;
// TODO(starka): This was just "stride", so check that width+height is OK.
- int8 stride_width;
- int8 stride_height;
- int8 dilation_width_factor;
- int8 dilation_height_factor;
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -745,8 +776,12 @@ struct ConvParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
struct DepthToSpaceParams {
@@ -756,8 +791,11 @@ struct DepthToSpaceParams {
struct DepthwiseParams {
PaddingType padding_type;
PaddingValues padding_values;
- int8 stride;
- int8 depth_multiplier;
+ int16 stride_width;
+ int16 stride_height;
+ int16 dilation_width_factor;
+ int16 dilation_height_factor;
+ int16 depth_multiplier;
// uint8 inference params.
// TODO(b/65838351): Use smaller types if appropriate.
int32 input_offset;
@@ -765,8 +803,12 @@ struct DepthwiseParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
};
struct DequantizationParams {
@@ -787,13 +829,17 @@ struct FullyConnectedParams {
int32 output_offset;
int32 output_multiplier;
int output_shift;
- int32 output_activation_min;
- int32 output_activation_max;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
FullyConnectedWeightsFormat weights_format;
};
struct GatherParams {
- int8 input_rank;
+ int16 input_rank;
int16 axis;
};
@@ -873,8 +919,8 @@ struct SoftmaxParams {
// for LogSoftmax.
double beta;
// uint8 inference params. Used even when beta defaults to 1.0.
- int32 input_beta_multiplier;
- int32 input_beta_left_shift;
+ int32 input_multiplier;
+ int32 input_left_shift;
// Reverse scaling is only used by LogSoftmax.
int32 reverse_scaling_divisor;
int32 reverse_scaling_right_shift;
@@ -924,6 +970,11 @@ struct TanhParams {
int input_left_shift;
};
+struct TransposeParams {
+ int8 perm_count;
+ int32 perm[4];
+};
+
template <typename P>
inline void SetActivationParams(float min, float max, P* params) {
params->float_activation_min = min;
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc
index 08f942c933..503ef28459 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.cc
+++ b/tensorflow/contrib/lite/kernels/kernel_util.cc
@@ -107,6 +107,9 @@ bool HaveSameShapes(const TfLiteTensor* input1, const TfLiteTensor* input2) {
return TfLiteIntArrayEqual(input1->dims, input2->dims);
}
+// TODO(petewarden): Having macros around this is ugly, look at other strategies
+// before replicating this approach elsewhere.
+#ifndef TF_LITE_STATIC_MEMORY
TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
const TfLiteTensor* input1,
const TfLiteTensor* input2,
@@ -125,5 +128,6 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
*output_shape = shape.release();
return kTfLiteOk;
}
+#endif // TF_LITE_STATIC_MEMORY
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
index 1bbea67b93..9739fd4514 100644
--- a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -16,7 +16,7 @@ limitations under the License.
// Layer Normalization LSTM op that applies normalization by mean and standard
// deviation to the activation of the LSTM layers. Please see
// https://arxiv.org/abs/1607.06450 for details.
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
index abc229f85a..479f6a7d3c 100644
--- a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index aaa3ce966e..5b996d00bc 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -893,18 +893,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
activation_out->type == kTfLiteFloat32 &&
concat_temp->type == kTfLiteFloat32 &&
activation_temp->type == kTfLiteFloat32) {
+ tflite::LstmCellParams op_params;
+ // Float LSTM cell does not need parameters to be set: leave untouched.
optimized_ops::LstmCell(
+ op_params,
// Inputs.
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<float>(weights), GetTensorDims(weights),
- GetTensorData<float>(bias), GetTensorDims(bias),
- GetTensorData<float>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(bias), GetTensorData<float>(bias),
+ GetTensorShape(prev_state), GetTensorData<float>(prev_state),
// Outputs.
- GetTensorData<float>(state_out), GetTensorDims(state_out),
- GetTensorData<float>(activation_out), GetTensorDims(activation_out),
- GetTensorData<float>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<float>(activation_temp), GetTensorDims(activation_temp));
+ GetTensorShape(state_out), GetTensorData<float>(state_out),
+ GetTensorShape(activation_out), GetTensorData<float>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
+ GetTensorShape(activation_temp), GetTensorData<float>(activation_temp));
} else if (input->type == kTfLiteUInt8 &&
prev_activation->type == kTfLiteUInt8 &&
weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
@@ -934,20 +937,25 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int accum_shift;
tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
&accum_shift);
+ tflite::LstmCellParams op_params;
+ op_params.weights_zero_point = weights->params.zero_point;
+ op_params.accum_multiplier = accum_multiplier;
+ op_params.accum_shift = accum_shift;
optimized_ops::LstmCell<4>(
+ op_params,
// Inputs.
- GetTensorData<uint8_t>(input), GetTensorDims(input),
- GetTensorData<uint8_t>(prev_activation), GetTensorDims(prev_activation),
- GetTensorData<uint8_t>(weights), GetTensorDims(weights),
- GetTensorData<int32_t>(bias), GetTensorDims(bias),
- GetTensorData<int16_t>(prev_state), GetTensorDims(prev_state),
+ GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(prev_activation),
+ GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
+ GetTensorData<uint8_t>(weights), GetTensorShape(bias),
+ GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
+ GetTensorData<int16_t>(prev_state),
// Outputs.
- GetTensorData<int16_t>(state_out), GetTensorDims(state_out),
- GetTensorData<uint8_t>(activation_out), GetTensorDims(activation_out),
- GetTensorData<uint8_t>(concat_temp), GetTensorDims(concat_temp),
- GetTensorData<int16_t>(activation_temp), GetTensorDims(activation_temp),
- weights->params.zero_point, accum_multiplier, accum_shift,
- gemm_context);
+ GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
+ GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
+ GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
+ GetTensorShape(activation_temp),
+ GetTensorData<int16_t>(activation_temp), gemm_context);
} else {
context->ReportError(context,
"Unsupported combination of data types for LstmCell");
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 66cf147d75..5153ce5634 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc
index c9124adcaf..fe69223222 100644
--- a/tensorflow/contrib/lite/kernels/mfcc_test.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h
index d66364c4d8..11e814daee 100644
--- a/tensorflow/contrib/lite/kernels/op_macros.h
+++ b/tensorflow/contrib/lite/kernels/op_macros.h
@@ -15,17 +15,55 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_MACROS_H_
+// If we're on a platform without standard IO functions, fall back to a
+// non-portable function.
+#ifdef TF_LITE_MCU_DEBUG_LOG
+
+// This header is pulled in from the support library at
+// https://github.com/google/stm32_bare_lib
+#include <debug_log.h>
+
+#define DEBUG_LOG(x) \
+ do { \
+ DebugLog(x); \
+ } while (0)
+
+inline void InfiniteLoop() {
+ DEBUG_LOG("HALTED\n");
+ while (1) {
+ }
+}
+#define TFLITE_ASSERT_FALSE InfiniteLoop();
+#define TFLITE_ABORT InfiniteLoop();
+
+#else // TF_LITE_MCU_DEBUG_LOG
+
+#include <cassert>
#include <cstdio>
+#include <cstdlib>
-#define TF_LITE_FATAL(msg) \
- do { \
- fprintf(stderr, "%s\n", (msg)); \
- exit(1); \
+#define DEBUG_LOG(x) \
+ do { \
+ fprintf(stderr, "%s", (x)); \
} while (0)
+
+#define TFLITE_ASSERT_FALSE assert(false)
+#define TFLITE_ABORT abort()
+
+#endif // TF_LITE_MCU_DEBUG_LOG
+
+#define TF_LITE_FATAL(msg) \
+ do { \
+ DEBUG_LOG(msg); \
+ DEBUG_LOG("\nFATAL\n"); \
+ TFLITE_ABORT; \
+ } while (0)
+
#define TF_LITE_ASSERT(x) \
do { \
if (!(x)) TF_LITE_FATAL(#x); \
} while (0)
+
#define TF_LITE_ASSERT_EQ(x, y) \
do { \
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index d94d821e87..4732a37a65 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -215,7 +215,7 @@ TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
return PrepareSimple(context, node);
}
-TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
// reduce_mean requires a buffer to store intermediate sum result.
@@ -274,7 +274,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
} else {
TF_LITE_ENSURE(
context,
- reference_ops::Mean<>(
+ reference_ops::QuantizedMeanOrSum<>(
GetTensorData<uint8_t>(op_context.input),
op_context.input->params.zero_point,
op_context.input->params.scale, op_context.input->dims->data,
@@ -286,7 +286,7 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum)));
+ GetTensorData<int>(temp_sum), /*compute_sum=*/false));
}
break;
default:
@@ -416,19 +416,57 @@ TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
}
}
+TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
+ OpContext op_context(context, node);
+ const auto& input = op_context.input;
+ const auto& output = op_context.output;
+ if (input->type != kTfLiteUInt8 ||
+ (input->params.scale == output->params.scale &&
+ input->params.zero_point == output->params.zero_point)) {
+ return EvalGeneric<kReference, kSum>(context, node);
+ } else {
+ // Rescaling 8bit reduce sum.
+ int num_axis = static_cast<int>(NumElements(op_context.axis));
+ TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
+ TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* temp_sum = GetTemporary(context, node, /*index=*/2);
+ // Resize the output tensor if the output tensor is dynamic.
+ if (IsDynamicTensor(op_context.output)) {
+ TF_LITE_ENSURE_OK(context,
+ ResizeTempAxis(context, &op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ TF_LITE_ENSURE_OK(context, ResizeTempSum(context, &op_context, temp_sum));
+ }
+
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::QuantizedMeanOrSum<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point, op_context.input->params.scale,
+ op_context.input->dims->data, op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale, op_context.output->dims->data,
+ op_context.output->dims->size, GetTensorData<int>(op_context.axis),
+ num_axis, op_context.params->keep_dims,
+ GetTensorData<int>(temp_index), GetTensorData<int>(resolved_axis),
+ GetTensorData<int32>(temp_sum), /*compute_sum=*/true));
+ }
+
+ return kTfLiteOk;
+}
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareMean,
+ reduce::PrepareMeanOrSum,
reduce::EvalMean<reduce::kReference>};
return &r;
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {
- reduce::Init, reduce::Free, reduce::PrepareSimple,
- reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
+ static TfLiteRegistration r = {reduce::Init, reduce::Free,
+ reduce::PrepareMeanOrSum, reduce::EvalSum};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index 6d289b14d8..fb2ec58ab2 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -488,6 +488,18 @@ TEST(ConstUint8SumOpTest, NotKeepDims) {
ArrayFloatNear({-0.823529, -0.815686}, kQuantizedTolerance)));
}
+TEST(ConstUint8SumOpTest, NotKeepDimsRescaling) {
+ float kQuantizedTolerance = GetTolerance(0.0, 2.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ SumOpConstModel m({TensorType_UINT8, {1, 3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {2}, 0.0, 2.0}, {1}, {1}, false);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2}));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear(
+ {1.2, 1.2}, kQuantizedTolerance)));
+}
+
TEST(ConstUint8SumOpTest, KeepDims) {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index c66959fdf4..2f4b663a28 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -118,6 +118,8 @@ TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
TfLiteRegistration* Register_UNPACK();
TfLiteRegistration* Register_FLOOR_DIV();
+TfLiteRegistration* Register_SQUARE();
+TfLiteRegistration* Register_ZEROS_LIKE();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -156,7 +158,9 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D());
- AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
AddBuiltin(BuiltinOperator_RNN, Register_RNN());
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
@@ -243,6 +247,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
+ AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
+ AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
index c1e0149c20..b1d25a9f50 100644
--- a/tensorflow/contrib/lite/kernels/relu1_test.cc
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc
index 9156917140..0fdb0a3935 100644
--- a/tensorflow/contrib/lite/kernels/test_util.cc
+++ b/tensorflow/contrib/lite/kernels/test_util.cc
@@ -74,8 +74,8 @@ void SingleOpModel::SetCustomOp(
CustomOptionsFormat_FLEXBUFFERS));
}
-void SingleOpModel::BuildInterpreter(
- std::vector<std::vector<int>> input_shapes) {
+void SingleOpModel::BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16) {
auto opcodes = builder_.CreateVector(opcodes_);
auto operators = builder_.CreateVector(operators_);
auto tensors = builder_.CreateVector(tensors_);
@@ -113,6 +113,8 @@ void SingleOpModel::BuildInterpreter(
CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk);
}
+ interpreter_->SetAllowFp16PrecisionForFp32(allow_fp32_relax_to_fp16);
+
// Modify delegate with function.
if (apply_delegate_fn_) {
apply_delegate_fn_(interpreter_.get());
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index bedbe93ae6..84deb0e0e8 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -182,7 +182,8 @@ class SingleOpModel {
// Build the interpreter for this model. Also, resize and allocate all
// tensors given the shapes of the inputs.
- void BuildInterpreter(std::vector<std::vector<int>> input_shapes);
+ void BuildInterpreter(std::vector<std::vector<int>> input_shapes,
+ bool allow_fp32_relax_to_fp16 = false);
void Invoke();
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index 6f2d98ede8..1c4a5ee91d 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -69,7 +69,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 4);
- // Currenlty only supports float32.
+ // Currently only supports float32.
const TfLiteType data_type = input->type;
TF_LITE_ENSURE(context, data_type == kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
@@ -117,19 +117,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Currently only support float32.
switch (input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
+ tflite::ConvParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = padding_size.width;
+ op_params.padding_values.height = padding_size.height;
+ op_params.stride_width = stride_width;
+ op_params.stride_height = stride_height;
+
reference_ops::TransposeConv(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(weights), GetTensorDims(weights), stride_width,
- stride_height, padding_size.width, padding_size.height,
- GetTensorData<float>(output), GetTensorDims(output),
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(weights), GetTensorData<float>(weights),
+ GetTensorShape(output), GetTensorData<float>(output),
// Last two args specify im2col which reference_ops ignores.
// (Note this does not lead to a performance regression, as the
// previous optimized version was just a copy of the reference code.)
// TODO(b/110208176): Allocate im2col tensors and switch to
// optimized_ops.
- GetTensorData<float>(output), GetTensorDims(output));
+ GetTensorShape(output), GetTensorData<float>(output));
break;
+ }
default:
context->ReportError(context, "Type %d, not currently supported.",
input->type);
diff --git a/tensorflow/contrib/lite/kernels/zeros_like.cc b/tensorflow/contrib/lite/kernels/zeros_like.cc
new file mode 100644
index 0000000000..cce5240a9b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace zeros_like {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ output->type = input->type;
+
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const int num_elements = NumElements(input);
+ switch (input->type) {
+ case kTfLiteInt64:
+ memset(GetTensorData<int64_t>(output), 0, num_elements * sizeof(int64_t));
+ break;
+ case kTfLiteInt32:
+ memset(GetTensorData<int32_t>(output), 0, num_elements * sizeof(int32_t));
+ break;
+ case kTfLiteFloat32:
+ memset(GetTensorData<float>(output), 0, num_elements * sizeof(float));
+ break;
+ default:
+ context->ReportError(context,
+ "ZerosLike only currently supports int64, int32, "
+ "and float32, got %d.",
+ input->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace zeros_like
+
+TfLiteRegistration* Register_ZEROS_LIKE() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ zeros_like::Prepare, zeros_like::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/zeros_like_test.cc b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
new file mode 100644
index 0000000000..d3382d1d5b
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/zeros_like_test.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ZerosLikeOpModel : public SingleOpModel {
+ public:
+ explicit ZerosLikeOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput(input);
+ SetBuiltinOp(BuiltinOperator_ZEROS_LIKE, BuiltinOptions_ZerosLikeOptions,
+ CreateZerosLikeOptions(builder_).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ int input() { return input_; }
+ int output() { return output_; }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+TEST(ZerosLikeOpModel, ZerosLikeFloat) {
+ ZerosLikeOpModel m({TensorType_FLOAT32, {2, 3}});
+ m.PopulateTensor<float>(m.input(), {-2.0, -1.0, 0.0, 1.0, 2.0, 3.0});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<float>(m.output()),
+ ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({2, 3}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt32) {
+ ZerosLikeOpModel m({TensorType_INT32, {1, 2, 2, 1}});
+ m.PopulateTensor<int32_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int32_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+TEST(ZerosLikeOpModel, ZerosLikeInt64) {
+ ZerosLikeOpModel m({TensorType_INT64, {1, 2, 2, 1}});
+ m.PopulateTensor<int64_t>(m.input(), {-2, -1, 0, 3});
+ m.Invoke();
+ EXPECT_THAT(m.ExtractVector<int64_t>(m.output()),
+ ElementsAreArray({0, 0, 0, 0}));
+ EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 2, 2, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 241865b3d8..ea2817beec 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -177,6 +177,11 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
+ // Initialize shape of tensors with null shape. Empty vectors are converted
+ // to nullptr for models that are constructed via flatbuffers::Pack.
+ if (flat_array == nullptr) {
+ return {};
+ }
std::vector<int> ret(flat_array->Length());
for (int i = 0; i < flat_array->Length(); i++) {
ret[i] = flat_array->Get(i);
@@ -184,6 +189,13 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
+// Used to determine how the op data parsing function creates its working space.
+class MallocDataAllocator : public BuiltinDataAllocator {
+ public:
+ void* Allocate(size_t size) override { return malloc(size); }
+ void Deallocate(void* data) override { free(data); }
+};
+
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
@@ -229,8 +241,9 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
op->custom_options()->size(), nullptr, registration);
} else {
void* builtin_data = nullptr;
- TF_LITE_ENSURE_STATUS(
- ParseOpData(op, op_type, error_reporter_, &builtin_data));
+ MallocDataAllocator malloc_allocator;
+ TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
+ &malloc_allocator, &builtin_data));
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index 8ee63d2a02..a36404399b 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -30,10 +30,11 @@ const TfLiteRegistration* MutableOpResolver::FindOp(const char* op,
}
void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
- TfLiteRegistration* registration,
+ const TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteRegistration new_registration = *registration;
+ new_registration.custom_name = nullptr;
new_registration.builtin_code = op;
new_registration.version = version;
auto op_key = std::make_pair(op, version);
@@ -42,15 +43,27 @@ void MutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
}
void MutableOpResolver::AddCustom(const char* name,
- TfLiteRegistration* registration,
+ const TfLiteRegistration* registration,
int min_version, int max_version) {
for (int version = min_version; version <= max_version; ++version) {
TfLiteRegistration new_registration = *registration;
new_registration.builtin_code = BuiltinOperator_CUSTOM;
+ new_registration.custom_name = name;
new_registration.version = version;
auto op_key = std::make_pair(name, version);
custom_ops_[op_key] = new_registration;
}
}
+void MutableOpResolver::AddAll(const MutableOpResolver& other) {
+ // map::insert does not replace existing elements, and map::insert_or_assign
+ // wasn't added until C++17.
+ for (const auto& other_builtin : other.builtins_) {
+ builtins_[other_builtin.first] = other_builtin.second;
+ }
+ for (const auto& other_custom_op : other.custom_ops_) {
+ custom_ops_[other_custom_op.first] = other_custom_op.second;
+ }
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
index c319041e9b..efd6cfac2a 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -57,10 +57,12 @@ class MutableOpResolver : public OpResolver {
const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
int version) const override;
const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
+ void AddBuiltin(tflite::BuiltinOperator op,
+ const TfLiteRegistration* registration, int min_version = 1,
+ int max_version = 1);
+ void AddCustom(const char* name, const TfLiteRegistration* registration,
int min_version = 1, int max_version = 1);
+ void AddAll(const MutableOpResolver& other);
private:
typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
diff --git a/tensorflow/contrib/lite/mutable_op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index db690eaab9..b70c703839 100644
--- a/tensorflow/contrib/lite/mutable_op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -36,6 +36,20 @@ TfLiteRegistration* GetDummyRegistration() {
return &registration;
}
+TfLiteStatus Dummy2Invoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteRegistration* GetDummy2Registration() {
+ static TfLiteRegistration registration = {
+ .init = nullptr,
+ .free = nullptr,
+ .prepare = nullptr,
+ .invoke = Dummy2Invoke,
+ };
+ return &registration;
+}
+
TEST(MutableOpResolverTest, FinOp) {
MutableOpResolver resolver;
resolver.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
@@ -119,6 +133,26 @@ TEST(MutableOpResolverTest, FindCustomOpWithUnsupportedVersion) {
EXPECT_EQ(found_registration, nullptr);
}
+TEST(MutableOpResolverTest, AddAll) {
+ MutableOpResolver resolver1;
+ resolver1.AddBuiltin(BuiltinOperator_ADD, GetDummyRegistration());
+ resolver1.AddBuiltin(BuiltinOperator_MUL, GetDummy2Registration());
+
+ MutableOpResolver resolver2;
+ resolver2.AddBuiltin(BuiltinOperator_SUB, GetDummyRegistration());
+ resolver2.AddBuiltin(BuiltinOperator_ADD, GetDummy2Registration());
+
+ // resolver2's ADD op should replace resolver1's ADD op, while augmenting
+ // non-overlapping ops.
+ resolver1.AddAll(resolver2);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_ADD, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_MUL, 1)->invoke,
+ GetDummy2Registration()->invoke);
+ ASSERT_EQ(resolver1.FindOp(BuiltinOperator_SUB, 1)->invoke,
+ GetDummyRegistration()->invoke);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 81dd459223..687944023b 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -364,6 +364,9 @@ typedef int (*ANeuralNetworksModel_identifyInputsAndOutputs_fn)(
ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs,
uint32_t outputCount, const uint32_t* outputs);
+typedef int (*ANeuralNetworksModel_relaxComputationFloat32toFloat16_fn)(
+ ANeuralNetworksModel* model, bool allow);
+
typedef int (*ANeuralNetworksExecution_create_fn)(
ANeuralNetworksCompilation* compilation,
ANeuralNetworksExecution** execution);
@@ -656,6 +659,34 @@ inline int ANeuralNetworksModel_identifyInputsAndOutputs(
}
/**
+ * Specifies whether {@link ANEURALNETWORKS_TENSOR_FLOAT32} is allowed to be
+ * calculated with range and/or precision as low as that of the IEEE 754 16-bit
+ * floating-point format. By default, {@link ANEURALNETWORKS_TENSOR_FLOAT32}
+ * must be calculated using at least the range and precision of the IEEE 754
+ * 32-bit floating-point format.
+ *
+ * @param model The model to be modified.
+ * @param allow 'true' indicates {@link ANEURALNETWORKS_TENSOR_FLOAT32} may be
+ * calculated with range and/or precision as low as that of the
+ * IEEE 754 16-bit floating point format. 'false' indicates
+ * {@link ANEURALNETWORKS_TENSOR_FLOAT32} must be calculated using
+ * at least the range and precision of the IEEE 754 32-bit floating
+ * point format.
+ *
+ * Attempting to modify a model once {@link ANeuralNetworksModel_finish} has
+ * been called will return an error.
+ *
+ * Available since API level 28.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ */
+inline int ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ ANeuralNetworksModel* model, bool allow) {
+ LOAD_FUNCTION(ANeuralNetworksModel_relaxComputationFloat32toFloat16);
+ EXECUTE_FUNCTION_RETURN(model, allow);
+}
+
+/**
* Create a {@link ANeuralNetworksCompilation} to compile the given model.
* This only creates the object. Compilation is only performed once
* {@link ANeuralNetworksCompilation_start} is invoked.
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 817486e898..f23a0ccb80 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -512,6 +512,10 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
break;
case tflite::BuiltinOperator_RESHAPE:
+ if (node.inputs->size != 2) {
+ logError("NNAPI only supports 2-input RESHAPE");
+ return kTfLiteError;
+ }
nn_op_type = ANEURALNETWORKS_RESHAPE;
// add_reshape_params(node.builtin_data);
break;
@@ -672,6 +676,9 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_UNPACK:
case tflite::BuiltinOperator_FLOOR_DIV:
case tflite::BuiltinOperator_REDUCE_ANY:
+ case tflite::BuiltinOperator_SQUARE:
+ case tflite::BuiltinOperator_ZEROS_LIKE:
+ case tflite::BuiltinOperator_FILL:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
@@ -757,6 +764,11 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
reinterpret_cast<const uint32_t*>(augmented_inputs.data()),
static_cast<uint32_t>(augmented_outputs.size()),
reinterpret_cast<const uint32_t*>(augmented_outputs.data())));
+
+ if (GetAndroidSdkVersionCached() >= 28) {
+ CHECK_NN(ANeuralNetworksModel_relaxComputationFloat32toFloat16(
+ nn_model_, interpreter->GetAllowFp16PrecisionForFp32()));
+ }
CHECK_NN(ANeuralNetworksModel_finish(nn_model_));
}
if (!nn_compiled_model_) {
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index f1f025f777..64ba2d8baa 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -25,7 +25,7 @@ void PrintIntVector(const std::vector<int>& v) {
void PrintTfLiteIntVector(const TfLiteIntArray* v) {
if (!v) {
- printf(" (null)");
+ printf(" (null)\n");
return;
}
for (int k = 0; k < v->size; k++) {
@@ -99,8 +99,12 @@ void PrintInterpreterState(Interpreter* interpreter) {
interpreter->node_and_registration(node_index);
const TfLiteNode& node = node_and_reg->first;
const TfLiteRegistration& reg = node_and_reg->second;
- printf("Node %3d Operator Builtin Code %3d\n", node_index,
- reg.builtin_code);
+ if (reg.custom_name != nullptr) {
+ printf("Node %3d Operator Custom Name %s\n", node_index, reg.custom_name);
+ } else {
+ printf("Node %3d Operator Builtin Code %3d\n", node_index,
+ reg.builtin_code);
+ }
printf(" Inputs:");
PrintTfLiteIntVector(node.inputs);
printf(" Outputs:");
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1c5516ae7c..1f48a826d4 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
import os as _os
import platform as _platform
import subprocess as _subprocess
@@ -30,7 +32,6 @@ from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
-
# Lazy load since some of the performance benchmark skylark rules
# break dependencies.
_toco_python = LazyLoader(
@@ -52,6 +53,31 @@ if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
_toco_from_proto_bin = "toco_from_protos"
+class ConverterMode(enum.Enum):
+ """Enum class defining the converters available to generate TFLite models.
+
+ WARNING: Experimental interface, subject to change.
+ """
+ # Convert model using TOCO such that all ops are TensorFlow Lite native ops.
+ #
+ # This is the only supported mode for any models that contain operations that
+ # cannot be resolved in TensorFlow.
+ DEFAULT = "DEFAULT"
+
+ # Convert model using TOCO such that only unsupported operations are
+ # represented as TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED = "TOCO_EXTENDED"
+
+ # Convert model using TOCO such that all operations are represented as
+ # TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+
+ def __str__(self):
+ return self.value
+
+
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
"""Convert `input_data_str` according to model and toco parameters.
@@ -128,7 +154,8 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges=False,
post_training_quantize=False,
dump_graphviz_dir=None,
- dump_graphviz_video=False):
+ dump_graphviz_video=False,
+ converter_mode=ConverterMode.DEFAULT):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -183,6 +210,8 @@ def build_toco_convert_protos(input_tensors,
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Returns:
model_flags, toco_flags: two protocol buffers describing the conversion
@@ -211,6 +240,11 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
+ if converter_mode == ConverterMode.TOCO_EXTENDED:
+ toco.allow_eager_ops = True
+ elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
+ toco.allow_eager_ops = True
+ toco.force_eager_ops = True
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
@@ -301,9 +335,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
Raises:
Defined in `build_toco_convert_protos`.
"""
- model_flags, toco_flags = build_toco_convert_protos(input_tensors,
- output_tensors,
- *args, **kwargs)
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors, output_tensors, *args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(),
toco_flags.SerializeToString(),
input_data.SerializeToString())
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 44dfb97b84..2be24455d8 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,6 +40,7 @@ from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import ConverterMode
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
@@ -113,6 +114,8 @@ class TocoConverter(object):
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Example usage:
@@ -179,6 +182,7 @@ class TocoConverter(object):
self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ self.converter_mode = ConverterMode.DEFAULT
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@@ -389,6 +393,7 @@ class TocoConverter(object):
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
+ ConverterMode option is unsupported for the model.
"""
# Checks dimensions in input tensor.
if self._has_valid_tensors():
@@ -439,12 +444,18 @@ class TocoConverter(object):
# Converts model.
if self._has_valid_tensors():
+ converter_kwargs["converter_mode"] = self.converter_mode
result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
**converter_kwargs)
else:
+ # Graphs without valid tensors cannot be loaded into tf.Session since they
+ # contain TFLite operation(s) that cannot be resolved in TensorFlow.
+ if self.converter_mode != ConverterMode.DEFAULT:
+ raise ValueError("This model can only be converted with the default "
+ "converter.")
result = _toco_convert_graph_def(
input_data=self._graph_def,
input_arrays_with_shape=self._input_arrays_with_shape,
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 3f8ea433ff..f112ed5cdd 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -402,6 +402,28 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
+ def testExtendedMode(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensures the model contains TensorFlow ops.
+ # TODO(nupurgarg): Check values once there is a Python delegate interface.
+ interpreter = Interpreter(model_content=tflite_model)
+ with self.assertRaises(RuntimeError) as error:
+ interpreter.allocate_tensors()
+ self.assertIn(
+ 'Regular TensorFlow ops are not supported by this interpreter. Make '
+ 'sure you invoke the Eager delegate before inference.',
+ str(error.exception))
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index cc08ed3fe9..c0ff7f37f9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -140,8 +140,11 @@ def _convert_model(flags):
if flags.change_concat_input_ranges:
converter.change_concat_input_ranges = (
flags.change_concat_input_ranges == "TRUE")
+
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
+ if flags.converter_mode:
+ converter.converter_mode = flags.converter_mode
if flags.post_training_quantize:
converter.post_training_quantize = flags.post_training_quantize
@@ -363,6 +366,8 @@ def run_main(_):
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
+
+ # Permitted ops flags.
parser.add_argument(
"--allow_custom_ops",
action="store_true",
@@ -371,6 +376,12 @@ def run_main(_):
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
+ parser.add_argument(
+ "--converter_mode",
+ type=lite.ConverterMode,
+ choices=list(lite.ConverterMode),
+ help=("Experimental flag, subject to change. ConverterMode indicating "
+ "which converter to use. (default ConverterMode.DEFAULT)"))
# Logging flags.
parser.add_argument(
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 55bf2c48b9..d892466c7a 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -25,14 +25,18 @@ py_binary(
],
)
+# TODO(wvo): re-enable this test once latest FlatBuffers has landed.
+
py_test(
name = "upgrade_schema_test",
size = "small",
srcs = ["upgrade_schema_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
"no_oss",
"no_pip",
+ "notap",
],
deps = [
":upgrade_schema",
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
index 11057203a8..22b4616ccb 100644
--- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <fstream>
#include <gtest/gtest.h>
-#include "flatbuffers/flatc.h" // flatbuffers
+#include "flatbuffers/flatc.h" // TF:flatbuffers
#include "tensorflow/core/platform/platform.h"
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index cf66403ec9..3da3188c3a 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -173,6 +173,9 @@ enum BuiltinOperator : byte {
REDUCE_MIN = 89,
FLOOR_DIV = 90,
REDUCE_ANY = 91,
+ SQUARE = 92,
+ ZEROS_LIKE = 93,
+ FILL = 94,
}
// Options for the builtin operators.
@@ -242,6 +245,9 @@ union BuiltinOptions {
LogicalNotOptions,
UnpackOptions,
FloorDivOptions,
+ SquareOptions,
+ ZerosLikeOptions,
+ FillOptions,
}
enum Padding : byte { SAME, VALID }
@@ -274,11 +280,15 @@ table Pool2DOptions {
}
table DepthwiseConv2DOptions {
+ // Parameters for DepthwiseConv version 1 or above.
padding:Padding;
stride_w:int;
stride_h:int;
depth_multiplier:int;
fused_activation_function:ActivationFunctionType;
+ // Parameters for DepthwiseConv version 2 or above.
+ dilation_w_factor:int = 1;
+ dilation_h_factor:int = 1;
}
table ConcatEmbeddingsOptions {
@@ -579,6 +589,15 @@ table UnpackOptions {
table FloorDivOptions {
}
+table SquareOptions {
+}
+
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 6d9630d75e..23ac8484de 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -226,6 +226,15 @@ struct UnpackOptionsT;
struct FloorDivOptions;
struct FloorDivOptionsT;
+struct SquareOptions;
+struct SquareOptionsT;
+
+struct ZerosLikeOptions;
+struct ZerosLikeOptionsT;
+
+struct FillOptions;
+struct FillOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -255,8 +264,8 @@ enum TensorType {
TensorType_MAX = TensorType_COMPLEX64
};
-inline TensorType (&EnumValuesTensorType())[9] {
- static TensorType values[] = {
+inline const TensorType (&EnumValuesTensorType())[9] {
+ static const TensorType values[] = {
TensorType_FLOAT32,
TensorType_FLOAT16,
TensorType_INT32,
@@ -270,8 +279,8 @@ inline TensorType (&EnumValuesTensorType())[9] {
return values;
}
-inline const char **EnumNamesTensorType() {
- static const char *names[] = {
+inline const char * const *EnumNamesTensorType() {
+ static const char * const names[] = {
"FLOAT32",
"FLOAT16",
"INT32",
@@ -383,12 +392,15 @@ enum BuiltinOperator {
BuiltinOperator_REDUCE_MIN = 89,
BuiltinOperator_FLOOR_DIV = 90,
BuiltinOperator_REDUCE_ANY = 91,
+ BuiltinOperator_SQUARE = 92,
+ BuiltinOperator_ZEROS_LIKE = 93,
+ BuiltinOperator_FILL = 94,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY
+ BuiltinOperator_MAX = BuiltinOperator_FILL
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
- static BuiltinOperator values[] = {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[94] {
+ static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator_CONCATENATION,
@@ -479,13 +491,16 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
BuiltinOperator_UNPACK,
BuiltinOperator_REDUCE_MIN,
BuiltinOperator_FLOOR_DIV,
- BuiltinOperator_REDUCE_ANY
+ BuiltinOperator_REDUCE_ANY,
+ BuiltinOperator_SQUARE,
+ BuiltinOperator_ZEROS_LIKE,
+ BuiltinOperator_FILL
};
return values;
}
-inline const char **EnumNamesBuiltinOperator() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOperator() {
+ static const char * const names[] = {
"ADD",
"AVERAGE_POOL_2D",
"CONCATENATION",
@@ -578,6 +593,9 @@ inline const char **EnumNamesBuiltinOperator() {
"REDUCE_MIN",
"FLOOR_DIV",
"REDUCE_ANY",
+ "SQUARE",
+ "ZEROS_LIKE",
+ "FILL",
nullptr
};
return names;
@@ -655,12 +673,15 @@ enum BuiltinOptions {
BuiltinOptions_LogicalNotOptions = 63,
BuiltinOptions_UnpackOptions = 64,
BuiltinOptions_FloorDivOptions = 65,
+ BuiltinOptions_SquareOptions = 66,
+ BuiltinOptions_ZerosLikeOptions = 67,
+ BuiltinOptions_FillOptions = 68,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions
+ BuiltinOptions_MAX = BuiltinOptions_FillOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
- static BuiltinOptions values[] = {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
+ static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
BuiltinOptions_DepthwiseConv2DOptions,
@@ -726,13 +747,16 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
BuiltinOptions_LogicalAndOptions,
BuiltinOptions_LogicalNotOptions,
BuiltinOptions_UnpackOptions,
- BuiltinOptions_FloorDivOptions
+ BuiltinOptions_FloorDivOptions,
+ BuiltinOptions_SquareOptions,
+ BuiltinOptions_ZerosLikeOptions,
+ BuiltinOptions_FillOptions
};
return values;
}
-inline const char **EnumNamesBuiltinOptions() {
- static const char *names[] = {
+inline const char * const *EnumNamesBuiltinOptions() {
+ static const char * const names[] = {
"NONE",
"Conv2DOptions",
"DepthwiseConv2DOptions",
@@ -799,6 +823,9 @@ inline const char **EnumNamesBuiltinOptions() {
"LogicalNotOptions",
"UnpackOptions",
"FloorDivOptions",
+ "SquareOptions",
+ "ZerosLikeOptions",
+ "FillOptions",
nullptr
};
return names;
@@ -1073,6 +1100,18 @@ template<> struct BuiltinOptionsTraits<FloorDivOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
};
+template<> struct BuiltinOptionsTraits<SquareOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_SquareOptions;
+};
+
+template<> struct BuiltinOptionsTraits<ZerosLikeOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_ZerosLikeOptions;
+};
+
+template<> struct BuiltinOptionsTraits<FillOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1624,6 +1663,30 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_FloorDivOptions ?
reinterpret_cast<const FloorDivOptionsT *>(value) : nullptr;
}
+ SquareOptionsT *AsSquareOptions() {
+ return type == BuiltinOptions_SquareOptions ?
+ reinterpret_cast<SquareOptionsT *>(value) : nullptr;
+ }
+ const SquareOptionsT *AsSquareOptions() const {
+ return type == BuiltinOptions_SquareOptions ?
+ reinterpret_cast<const SquareOptionsT *>(value) : nullptr;
+ }
+ ZerosLikeOptionsT *AsZerosLikeOptions() {
+ return type == BuiltinOptions_ZerosLikeOptions ?
+ reinterpret_cast<ZerosLikeOptionsT *>(value) : nullptr;
+ }
+ const ZerosLikeOptionsT *AsZerosLikeOptions() const {
+ return type == BuiltinOptions_ZerosLikeOptions ?
+ reinterpret_cast<const ZerosLikeOptionsT *>(value) : nullptr;
+ }
+ FillOptionsT *AsFillOptions() {
+ return type == BuiltinOptions_FillOptions ?
+ reinterpret_cast<FillOptionsT *>(value) : nullptr;
+ }
+ const FillOptionsT *AsFillOptions() const {
+ return type == BuiltinOptions_FillOptions ?
+ reinterpret_cast<const FillOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -1636,16 +1699,16 @@ enum Padding {
Padding_MAX = Padding_VALID
};
-inline Padding (&EnumValuesPadding())[2] {
- static Padding values[] = {
+inline const Padding (&EnumValuesPadding())[2] {
+ static const Padding values[] = {
Padding_SAME,
Padding_VALID
};
return values;
}
-inline const char **EnumNamesPadding() {
- static const char *names[] = {
+inline const char * const *EnumNamesPadding() {
+ static const char * const names[] = {
"SAME",
"VALID",
nullptr
@@ -1669,8 +1732,8 @@ enum ActivationFunctionType {
ActivationFunctionType_MAX = ActivationFunctionType_SIGN_BIT
};
-inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
- static ActivationFunctionType values[] = {
+inline const ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
+ static const ActivationFunctionType values[] = {
ActivationFunctionType_NONE,
ActivationFunctionType_RELU,
ActivationFunctionType_RELU_N1_TO_1,
@@ -1681,8 +1744,8 @@ inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] {
return values;
}
-inline const char **EnumNamesActivationFunctionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesActivationFunctionType() {
+ static const char * const names[] = {
"NONE",
"RELU",
"RELU_N1_TO_1",
@@ -1707,8 +1770,8 @@ enum LSHProjectionType {
LSHProjectionType_MAX = LSHProjectionType_DENSE
};
-inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
- static LSHProjectionType values[] = {
+inline const LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
+ static const LSHProjectionType values[] = {
LSHProjectionType_UNKNOWN,
LSHProjectionType_SPARSE,
LSHProjectionType_DENSE
@@ -1716,8 +1779,8 @@ inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] {
return values;
}
-inline const char **EnumNamesLSHProjectionType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSHProjectionType() {
+ static const char * const names[] = {
"UNKNOWN",
"SPARSE",
"DENSE",
@@ -1738,16 +1801,16 @@ enum FullyConnectedOptionsWeightsFormat {
FullyConnectedOptionsWeightsFormat_MAX = FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
-inline FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
- static FullyConnectedOptionsWeightsFormat values[] = {
+inline const FullyConnectedOptionsWeightsFormat (&EnumValuesFullyConnectedOptionsWeightsFormat())[2] {
+ static const FullyConnectedOptionsWeightsFormat values[] = {
FullyConnectedOptionsWeightsFormat_DEFAULT,
FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8
};
return values;
}
-inline const char **EnumNamesFullyConnectedOptionsWeightsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesFullyConnectedOptionsWeightsFormat() {
+ static const char * const names[] = {
"DEFAULT",
"SHUFFLED4x16INT8",
nullptr
@@ -1767,16 +1830,16 @@ enum LSTMKernelType {
LSTMKernelType_MAX = LSTMKernelType_BASIC
};
-inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
- static LSTMKernelType values[] = {
+inline const LSTMKernelType (&EnumValuesLSTMKernelType())[2] {
+ static const LSTMKernelType values[] = {
LSTMKernelType_FULL,
LSTMKernelType_BASIC
};
return values;
}
-inline const char **EnumNamesLSTMKernelType() {
- static const char *names[] = {
+inline const char * const *EnumNamesLSTMKernelType() {
+ static const char * const names[] = {
"FULL",
"BASIC",
nullptr
@@ -1797,8 +1860,8 @@ enum CombinerType {
CombinerType_MAX = CombinerType_SQRTN
};
-inline CombinerType (&EnumValuesCombinerType())[3] {
- static CombinerType values[] = {
+inline const CombinerType (&EnumValuesCombinerType())[3] {
+ static const CombinerType values[] = {
CombinerType_SUM,
CombinerType_MEAN,
CombinerType_SQRTN
@@ -1806,8 +1869,8 @@ inline CombinerType (&EnumValuesCombinerType())[3] {
return values;
}
-inline const char **EnumNamesCombinerType() {
- static const char *names[] = {
+inline const char * const *EnumNamesCombinerType() {
+ static const char * const names[] = {
"SUM",
"MEAN",
"SQRTN",
@@ -1827,15 +1890,15 @@ enum CustomOptionsFormat {
CustomOptionsFormat_MAX = CustomOptionsFormat_FLEXBUFFERS
};
-inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
- static CustomOptionsFormat values[] = {
+inline const CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] {
+ static const CustomOptionsFormat values[] = {
CustomOptionsFormat_FLEXBUFFERS
};
return values;
}
-inline const char **EnumNamesCustomOptionsFormat() {
- static const char *names[] = {
+inline const char * const *EnumNamesCustomOptionsFormat() {
+ static const char * const names[] = {
"FLEXBUFFERS",
nullptr
};
@@ -1880,13 +1943,13 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_MIN) &&
- verifier.Verify(min()) &&
+ verifier.VerifyVector(min()) &&
VerifyOffset(verifier, VT_MAX) &&
- verifier.Verify(max()) &&
+ verifier.VerifyVector(max()) &&
VerifyOffset(verifier, VT_SCALE) &&
- verifier.Verify(scale()) &&
+ verifier.VerifyVector(scale()) &&
VerifyOffset(verifier, VT_ZERO_POINT) &&
- verifier.Verify(zero_point()) &&
+ verifier.VerifyVector(zero_point()) &&
verifier.EndTable();
}
QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -1997,11 +2060,11 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SHAPE) &&
- verifier.Verify(shape()) &&
+ verifier.VerifyVector(shape()) &&
VerifyField<int8_t>(verifier, VT_TYPE) &&
VerifyField<uint32_t>(verifier, VT_BUFFER) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
VerifyOffset(verifier, VT_QUANTIZATION) &&
verifier.VerifyTable(quantization()) &&
VerifyField<uint8_t>(verifier, VT_IS_VARIABLE) &&
@@ -2318,12 +2381,16 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable {
int32_t stride_h;
int32_t depth_multiplier;
ActivationFunctionType fused_activation_function;
+ int32_t dilation_w_factor;
+ int32_t dilation_h_factor;
DepthwiseConv2DOptionsT()
: padding(Padding_SAME),
stride_w(0),
stride_h(0),
depth_multiplier(0),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ dilation_w_factor(1),
+ dilation_h_factor(1) {
}
};
@@ -2334,7 +2401,9 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VT_STRIDE_W = 6,
VT_STRIDE_H = 8,
VT_DEPTH_MULTIPLIER = 10,
- VT_FUSED_ACTIVATION_FUNCTION = 12
+ VT_FUSED_ACTIVATION_FUNCTION = 12,
+ VT_DILATION_W_FACTOR = 14,
+ VT_DILATION_H_FACTOR = 16
};
Padding padding() const {
return static_cast<Padding>(GetField<int8_t>(VT_PADDING, 0));
@@ -2351,6 +2420,12 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ int32_t dilation_w_factor() const {
+ return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
+ }
+ int32_t dilation_h_factor() const {
+ return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_PADDING) &&
@@ -2358,6 +2433,8 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tab
VerifyField<int32_t>(verifier, VT_STRIDE_H) &&
VerifyField<int32_t>(verifier, VT_DEPTH_MULTIPLIER) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_W_FACTOR) &&
+ VerifyField<int32_t>(verifier, VT_DILATION_H_FACTOR) &&
verifier.EndTable();
}
DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2383,6 +2460,12 @@ struct DepthwiseConv2DOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_dilation_w_factor(int32_t dilation_w_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
+ }
+ void add_dilation_h_factor(int32_t dilation_h_factor) {
+ fbb_.AddElement<int32_t>(DepthwiseConv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
+ }
explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2401,8 +2484,12 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
int32_t stride_w = 0,
int32_t stride_h = 0,
int32_t depth_multiplier = 0,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ int32_t dilation_w_factor = 1,
+ int32_t dilation_h_factor = 1) {
DepthwiseConv2DOptionsBuilder builder_(_fbb);
+ builder_.add_dilation_h_factor(dilation_h_factor);
+ builder_.add_dilation_w_factor(dilation_w_factor);
builder_.add_depth_multiplier(depth_multiplier);
builder_.add_stride_h(stride_h);
builder_.add_stride_w(stride_w);
@@ -2443,9 +2530,9 @@ struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Ta
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_NUM_CHANNELS) &&
VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) &&
- verifier.Verify(num_columns_per_channel()) &&
+ verifier.VerifyVector(num_columns_per_channel()) &&
VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) &&
- verifier.Verify(embedding_dim_per_channel()) &&
+ verifier.VerifyVector(embedding_dim_per_channel()) &&
verifier.EndTable();
}
ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -3543,7 +3630,7 @@ struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_NEW_SHAPE) &&
- verifier.Verify(new_shape()) &&
+ verifier.VerifyVector(new_shape()) &&
verifier.EndTable();
}
ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -4207,7 +4294,7 @@ struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_SQUEEZE_DIMS) &&
- verifier.Verify(squeeze_dims()) &&
+ verifier.VerifyVector(squeeze_dims()) &&
verifier.EndTable();
}
SqueezeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -5803,6 +5890,126 @@ inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(
flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct SquareOptionsT : public flatbuffers::NativeTable {
+ typedef SquareOptions TableType;
+ SquareOptionsT() {
+ }
+};
+
+struct SquareOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef SquareOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ SquareOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<SquareOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct SquareOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit SquareOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ SquareOptionsBuilder &operator=(const SquareOptionsBuilder &);
+ flatbuffers::Offset<SquareOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<SquareOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ SquareOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct ZerosLikeOptionsT : public flatbuffers::NativeTable {
+ typedef ZerosLikeOptions TableType;
+ ZerosLikeOptionsT() {
+ }
+};
+
+struct ZerosLikeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ZerosLikeOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ ZerosLikeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ZerosLikeOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ZerosLikeOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit ZerosLikeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ZerosLikeOptionsBuilder &operator=(const ZerosLikeOptionsBuilder &);
+ flatbuffers::Offset<ZerosLikeOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ZerosLikeOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ ZerosLikeOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct FillOptionsT : public flatbuffers::NativeTable {
+ typedef FillOptions TableType;
+ FillOptionsT() {
+ }
+};
+
+struct FillOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FillOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ FillOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FillOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FillOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit FillOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FillOptionsBuilder &operator=(const FillOptionsBuilder &);
+ flatbuffers::Offset<FillOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FillOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ FillOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5834,7 +6041,7 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<int8_t>(verifier, VT_BUILTIN_CODE) &&
VerifyOffset(verifier, VT_CUSTOM_CODE) &&
- verifier.Verify(custom_code()) &&
+ verifier.VerifyString(custom_code()) &&
VerifyField<int32_t>(verifier, VT_VERSION) &&
verifier.EndTable();
}
@@ -6131,6 +6338,15 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FloorDivOptions *builtin_options_as_FloorDivOptions() const {
return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast<const FloorDivOptions *>(builtin_options()) : nullptr;
}
+ const SquareOptions *builtin_options_as_SquareOptions() const {
+ return builtin_options_type() == BuiltinOptions_SquareOptions ? static_cast<const SquareOptions *>(builtin_options()) : nullptr;
+ }
+ const ZerosLikeOptions *builtin_options_as_ZerosLikeOptions() const {
+ return builtin_options_type() == BuiltinOptions_ZerosLikeOptions ? static_cast<const ZerosLikeOptions *>(builtin_options()) : nullptr;
+ }
+ const FillOptions *builtin_options_as_FillOptions() const {
+ return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6144,17 +6360,17 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_OPCODE_INDEX) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyField<uint8_t>(verifier, VT_BUILTIN_OPTIONS_TYPE) &&
VerifyOffset(verifier, VT_BUILTIN_OPTIONS) &&
VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) &&
VerifyOffset(verifier, VT_CUSTOM_OPTIONS) &&
- verifier.Verify(custom_options()) &&
+ verifier.VerifyVector(custom_options()) &&
VerifyField<int8_t>(verifier, VT_CUSTOM_OPTIONS_FORMAT) &&
VerifyOffset(verifier, VT_MUTATING_VARIABLE_INPUTS) &&
- verifier.Verify(mutating_variable_inputs()) &&
+ verifier.VerifyVector(mutating_variable_inputs()) &&
verifier.EndTable();
}
OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6422,6 +6638,18 @@ template<> inline const FloorDivOptions *Operator::builtin_options_as<FloorDivOp
return builtin_options_as_FloorDivOptions();
}
+template<> inline const SquareOptions *Operator::builtin_options_as<SquareOptions>() const {
+ return builtin_options_as_SquareOptions();
+}
+
+template<> inline const ZerosLikeOptions *Operator::builtin_options_as<ZerosLikeOptions>() const {
+ return builtin_options_as_ZerosLikeOptions();
+}
+
+template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>() const {
+ return builtin_options_as_FillOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -6545,17 +6773,17 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_TENSORS) &&
- verifier.Verify(tensors()) &&
+ verifier.VerifyVector(tensors()) &&
verifier.VerifyVectorOfTables(tensors()) &&
VerifyOffset(verifier, VT_INPUTS) &&
- verifier.Verify(inputs()) &&
+ verifier.VerifyVector(inputs()) &&
VerifyOffset(verifier, VT_OUTPUTS) &&
- verifier.Verify(outputs()) &&
+ verifier.VerifyVector(outputs()) &&
VerifyOffset(verifier, VT_OPERATORS) &&
- verifier.Verify(operators()) &&
+ verifier.VerifyVector(operators()) &&
verifier.VerifyVectorOfTables(operators()) &&
VerifyOffset(verifier, VT_NAME) &&
- verifier.Verify(name()) &&
+ verifier.VerifyString(name()) &&
verifier.EndTable();
}
SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6645,7 +6873,7 @@ struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_DATA) &&
- verifier.Verify(data()) &&
+ verifier.VerifyVector(data()) &&
verifier.EndTable();
}
BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6734,18 +6962,18 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
return VerifyTableStart(verifier) &&
VerifyField<uint32_t>(verifier, VT_VERSION) &&
VerifyOffset(verifier, VT_OPERATOR_CODES) &&
- verifier.Verify(operator_codes()) &&
+ verifier.VerifyVector(operator_codes()) &&
verifier.VerifyVectorOfTables(operator_codes()) &&
VerifyOffset(verifier, VT_SUBGRAPHS) &&
- verifier.Verify(subgraphs()) &&
+ verifier.VerifyVector(subgraphs()) &&
verifier.VerifyVectorOfTables(subgraphs()) &&
VerifyOffset(verifier, VT_DESCRIPTION) &&
- verifier.Verify(description()) &&
+ verifier.VerifyString(description()) &&
VerifyOffset(verifier, VT_BUFFERS) &&
- verifier.Verify(buffers()) &&
+ verifier.VerifyVector(buffers()) &&
verifier.VerifyVectorOfTables(buffers()) &&
VerifyOffset(verifier, VT_METADATA_BUFFER) &&
- verifier.Verify(metadata_buffer()) &&
+ verifier.VerifyVector(metadata_buffer()) &&
verifier.EndTable();
}
ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -6996,6 +7224,8 @@ inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const
{ auto _e = stride_h(); _o->stride_h = _e; };
{ auto _e = depth_multiplier(); _o->depth_multiplier = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = dilation_w_factor(); _o->dilation_w_factor = _e; };
+ { auto _e = dilation_h_factor(); _o->dilation_h_factor = _e; };
}
inline flatbuffers::Offset<DepthwiseConv2DOptions> DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7011,13 +7241,17 @@ inline flatbuffers::Offset<DepthwiseConv2DOptions> CreateDepthwiseConv2DOptions(
auto _stride_h = _o->stride_h;
auto _depth_multiplier = _o->depth_multiplier;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _dilation_w_factor = _o->dilation_w_factor;
+ auto _dilation_h_factor = _o->dilation_h_factor;
return tflite::CreateDepthwiseConv2DOptions(
_fbb,
_padding,
_stride_w,
_stride_h,
_depth_multiplier,
- _fused_activation_function);
+ _fused_activation_function,
+ _dilation_w_factor,
+ _dilation_h_factor);
}
inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -8661,6 +8895,75 @@ inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::F
_fbb);
}
+inline SquareOptionsT *SquareOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new SquareOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void SquareOptions::UnPackTo(SquareOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<SquareOptions> SquareOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateSquareOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<SquareOptions> CreateSquareOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquareOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SquareOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateSquareOptions(
+ _fbb);
+}
+
+inline ZerosLikeOptionsT *ZerosLikeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new ZerosLikeOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void ZerosLikeOptions::UnPackTo(ZerosLikeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> ZerosLikeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateZerosLikeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ZerosLikeOptions> CreateZerosLikeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ZerosLikeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ZerosLikeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateZerosLikeOptions(
+ _fbb);
+}
+
+inline FillOptionsT *FillOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FillOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FillOptions::UnPackTo(FillOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<FillOptions> FillOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFillOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FillOptions> CreateFillOptions(flatbuffers::FlatBufferBuilder &_fbb, const FillOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FillOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateFillOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -9110,6 +9413,18 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9388,6 +9703,18 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9654,6 +9981,18 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const FloorDivOptionsT *>(value);
return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<const SquareOptionsT *>(value);
+ return CreateSquareOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<const ZerosLikeOptionsT *>(value);
+ return CreateZerosLikeOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<const FillOptionsT *>(value);
+ return CreateFillOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9920,6 +10259,18 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new FloorDivOptionsT(*reinterpret_cast<FloorDivOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_SquareOptions: {
+ value = new SquareOptionsT(*reinterpret_cast<SquareOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_ZerosLikeOptions: {
+ value = new ZerosLikeOptionsT(*reinterpret_cast<ZerosLikeOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_FillOptions: {
+ value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10252,6 +10603,21 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_SquareOptions: {
+ auto ptr = reinterpret_cast<SquareOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_ZerosLikeOptions: {
+ auto ptr = reinterpret_cast<ZerosLikeOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_FillOptions: {
+ auto ptr = reinterpret_cast<FillOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
@@ -10262,6 +10628,10 @@ inline const tflite::Model *GetModel(const void *buf) {
return flatbuffers::GetRoot<tflite::Model>(buf);
}
+inline const tflite::Model *GetSizePrefixedModel(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<tflite::Model>(buf);
+}
+
inline const char *ModelIdentifier() {
return "TFL3";
}
@@ -10276,6 +10646,11 @@ inline bool VerifyModelBuffer(
return verifier.VerifyBuffer<tflite::Model>(ModelIdentifier());
}
+inline bool VerifySizePrefixedModelBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<tflite::Model>(ModelIdentifier());
+}
+
inline const char *ModelExtension() {
return "tflite";
}
@@ -10286,6 +10661,12 @@ inline void FinishModelBuffer(
fbb.Finish(root, ModelIdentifier());
}
+inline void FinishSizePrefixedModelBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<tflite::Model> root) {
+ fbb.FinishSizePrefixed(root, ModelIdentifier());
+}
+
inline std::unique_ptr<ModelT> UnPackModel(
const void *buf,
const flatbuffers::resolver_function_t *res = nullptr) {
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 3a6c16cafc..a4736bfee9 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -7,7 +7,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow/contrib/lite:build_def.bzl",
"gen_zip_test",
- "generated_test_models",
+ "generated_test_models_all",
)
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
load(
@@ -29,6 +29,7 @@ load(
"--unzip_binary_path=/usr/bin/unzip",
],
}),
+ conversion_mode = conversion_mode,
data = [
":zip_%s" % test_name,
],
@@ -59,7 +60,7 @@ load(
"//tensorflow/core:android_tensorflow_test_lib",
],
}),
-) for test_name in generated_test_models()]
+) for conversion_mode, test_name in generated_test_models_all()]
test_suite(
name = "generated_zip_tests",
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 32f02a4f6c..014c80b5ef 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -80,7 +80,10 @@ parser.add_argument(
"--save_graphdefs",
action="store_true",
help="Include intermediate graphdefs in the output zip files.")
-
+parser.add_argument(
+ "--run_with_extended",
+ action="store_true",
+ help="Whether the TFLite Extended converter is being used.")
RANDOM_SEED = 342
TEST_INPUT_DEPTH = 3
@@ -320,10 +323,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
output tflite model, log_txt from conversion
or None, log_txt if it did not convert properly.
"""
+ input_arrays = [x[0] for x in input_tensors]
data_types = [_TF_TYPE_INFO[x[2]][1] for x in input_tensors]
opts = toco_options(
data_types=data_types,
- input_arrays=[x[0] for x in input_tensors],
+ input_arrays=input_arrays,
shapes=[x[1] for x in input_tensors],
output_arrays=output_tensors,
extra_toco_options=extra_toco_options)
@@ -335,6 +339,11 @@ def toco_convert(graph_def_str, input_tensors, output_tensors,
graphdef_file.flush()
# TODO(aselle): Switch this to subprocess at some point.
+ if "pb2lite" in bin_path and FLAGS.run_with_extended:
+ opts = ("--input_arrays={0} --output_arrays={1}".format(
+ ",".join(input_arrays), ",".join(output_tensors)))
+ elif FLAGS.run_with_extended:
+ opts += " --allow_eager_ops --force_eager_ops"
cmd = ("%s --input_file=%s --output_file=%s %s > %s 2>&1" %
(bin_path, graphdef_file.name, output_file.name, opts,
stdout_file.name))
@@ -1425,6 +1434,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3], [1, 10, 10, 3]],
"filter_size": [[1, 1], [1, 2], [3, 3]],
"strides": [[1, 1, 1, 1], [1, 3, 3, 1]],
+ "dilations": [[1, 1, 1, 1], [1, 3, 2, 1], [1, 2, 2, 1]],
"channel_multiplier": [1, 2],
"rate": [[1, 1]],
"padding": ["SAME", "VALID"],
@@ -1435,6 +1445,7 @@ def make_depthwiseconv_tests(zip_path):
"input_shape": [[1, 3, 4, 3]],
"filter_size": [[1, 1]],
"strides": [[1, 1, 2, 1]], # TF needs [1, x, x, 1]
+ "dilations": [[1, 1, 1, 1], [1, 2, 2, 1]],
"channel_multiplier": [2],
"rate": [[2, 2]], # Only [1, 1] is supported
"padding": ["SAME"],
@@ -1502,7 +1513,7 @@ def make_split_tests(zip_path):
dtype=tf.float32, name="input", shape=parameters["input_shape"])
out = tf.split(
input_tensor, parameters["num_or_size_splits"], parameters["axis"])
- return [input_tensor], out
+ return [input_tensor], [out[0]]
def build_inputs(parameters, sess, inputs, outputs):
values = [create_tensor_data(np.float32, parameters["input_shape"])]
@@ -2510,10 +2521,12 @@ def make_topk_tests(zip_path):
shape=parameters["input_shape"])
if parameters["input_k"] is not None:
k = tf.placeholder(dtype=tf.int32, name="input_k", shape=[])
+ inputs = [input_value, k]
else:
k = tf.constant(3, name="k")
+ inputs = [input_value]
out = tf.nn.top_k(input_value, k)
- return [input_value, k], [out[1]]
+ return inputs, [out[1]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(parameters["input_dtype"],
@@ -2821,6 +2834,31 @@ def make_neg_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_zeros_like_tests(zip_path):
+ """Make a set of tests to do zeros_like."""
+
+ test_parameters = [{
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape": [[], [1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+ }]
+
+ def build_graph(parameters):
+ """Build the zeros_like op testing graph."""
+ input_tensor = tf.placeholder(
+ dtype=parameters["input_dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ out = tf.zeros_like(input_tensor)
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ values = create_tensor_data(parameters["input_dtype"],
+ parameters["input_shape"])
+ return [values], sess.run(outputs, feed_dict=dict(zip(inputs, [values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def _make_elementwise_tests(op):
"""Make a set of tests to do element-wise operations."""
@@ -2871,6 +2909,11 @@ def make_rsqrt_tests(zip_path):
return _make_elementwise_tests(tf.rsqrt)(zip_path)
+def make_square_tests(zip_path):
+ """Make a set of tests to do square."""
+ return _make_elementwise_tests(tf.square)(zip_path)
+
+
def make_where_tests(zip_path):
"""Make a set of tests to do where."""
@@ -3208,7 +3251,7 @@ def make_unpack_tests(zip_path):
input_tensor = tf.placeholder(
dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
- return [input_tensor], outs
+ return [input_tensor], [outs[0]]
def build_inputs(parameters, sess, inputs, outputs):
input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
@@ -3286,7 +3329,11 @@ def main(unused_args):
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
- test_function = ("make_%s_tests" % out.replace(".zip", ""))
+ # Some zip filenames contain a postfix identifying the conversion mode. The
+ # list of valid conversion modes is defined in
+ # generated_test_conversion_modes() in build_def.bzl.
+ test_function = ("make_%s_tests" % (out.replace(".zip", "").replace(
+ "pb2lite", "").replace("toco-extended", "").rstrip("_")))
if test_function not in globals():
raise RuntimeError("Can't find a test function to create %r. Tried %r" %
(out, test_function))
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index bea90f1ce8..96b88b60fc 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -347,6 +347,7 @@ tf_cc_test(
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"@com_google_googletest//:gtest_main",
],
@@ -407,8 +408,11 @@ tf_cc_binary(
":toco_port",
":toco_tooling",
":types_proto_cc",
- "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
],
)
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b52a79282c..61e9106783 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -470,6 +470,17 @@ void ConvertDepthwiseConvOperator(const Model& model,
strides.mutable_list()->add_i(src_op.stride_height);
strides.mutable_list()->add_i(src_op.stride_width);
strides.mutable_list()->add_i(1);
+ // TODO(b/116063589): To return a working TF GraphDef, we should be returning
+ // the correct SpaceToBatchNd and BatchToSpaceND operation before and after
+ // the conv since TF doesn't support dilations.
+ if ((src_op.dilation_width_factor != 1) ||
+ (src_op.dilation_height_factor != 1)) {
+ auto& dilations = (*dc2d_op->mutable_attr())["dilations"];
+ dilations.mutable_list()->add_i(1);
+ dilations.mutable_list()->add_i(src_op.dilation_height_factor);
+ dilations.mutable_list()->add_i(src_op.dilation_width_factor);
+ dilations.mutable_list()->add_i(1);
+ }
string padding;
if (src_op.padding.type == PaddingType::kSame) {
padding = "SAME";
@@ -1968,6 +1979,19 @@ void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
(*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
}
+void ConvertZerosLikeOperator(const Model& model,
+ const TensorFlowZerosLikeOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* zeros_like_op = tensorflow_graph->add_node();
+ zeros_like_op->set_op(op_name);
+ zeros_like_op->set_name(src_op.outputs[0]);
+ DCHECK_EQ(src_op.inputs.size(), 1);
+ *zeros_like_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*zeros_like_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2233,6 +2257,10 @@ void ConvertOperator(const Model& model, const Operator& src_op,
} else if (src_op.type == OperatorType::kUnpack) {
ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
"Unpack", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kZerosLike) {
+ ConvertZerosLikeOperator(
+ model, static_cast<const TensorFlowZerosLikeOperator&>(src_op),
+ "ZerosLike", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 84680b968e..aba7536cbd 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -38,7 +38,7 @@ There are two approaches to running TOCO via command line.
examples below use `tflite_convert` for simplicity.
* Example: `tflite_convert --output_file=...`
* `bazel`: In order to run the latest version of TOCO, [clone the TensorFlow
- repository](https://www.tensorflow.org/install/install_sources#clone_the_tensorflow_repository)
+ repository](https://www.tensorflow.org/install/source)
and use `bazel`. This is the recommended approach for converting models that
utilize new features that were not supported by TOCO in TensorFlow 1.9.
* Example: `bazel run
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 51f808d4f0..910fa4c8de 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -260,7 +260,7 @@ interpreter.allocate_tensors()
In order to run the latest version of the TOCO Python API, clone the TensorFlow
repository, configure the installation, and build and install the pip package.
Detailed instructions are available
-[here](https://www.tensorflow.org/install/install_sources).
+[here](https://www.tensorflow.org/install/source).
### Converting models prior to TensorFlow 1.9. <a name="pre-tensorflow-1.9"></a>
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index fdd0632451..4d213b3f9c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -133,7 +133,6 @@ DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs)
DECLARE_GRAPH_TRANSFORMATION(MergeReshapeIntoPrecedingTranspose)
DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
DECLARE_GRAPH_TRANSFORMATION(IdentifyPRelu)
-DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(MoveBinaryOperatorBeforeReshape)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
@@ -266,6 +265,17 @@ class EnsureUint8WeightsSafeForFastInt8Kernels : public GraphTransformation {
bool has_default_ranges_flag_ = false;
};
+class IdentifyDilatedConv : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "IdentifyDilatedConv"; }
+ bool identify_depthwise_conv() const { return identify_depthwise_conv_; }
+ void set_identify_depthwise_conv(bool val) { identify_depthwise_conv_ = val; }
+
+ private:
+ bool identify_depthwise_conv_ = true;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
index d49857cfc2..aac77eb39e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc
@@ -53,50 +53,11 @@ namespace toco {
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
-bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
- const auto it = model->operators.begin() + op_index;
- auto* stb_op = it->get();
-
- // 1. IDENTIFY OPERATORS
- // ***************************************************************************
- // SpaceToBatch Op.
- if (stb_op->type != OperatorType::kSpaceToBatchND) {
- return false;
- }
- if (stb_op->inputs.size() != 3) {
- return false;
- }
- CHECK_EQ(stb_op->outputs.size(), 1);
- // Extract the dilation factor from Input[1] of SpaceToBatch
- // TODO(mjmatthews): Support 2D dilation factors.
- const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
- if (!block_shape_array.buffer) {
- return false;
- }
- CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
- int dilation_factor =
- block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
-
- // Expand Op
- auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
- if (!post_stb_op) {
- return false;
- }
- bool has_expand_op = false;
- if (post_stb_op->type == OperatorType::kExpandDims) {
- has_expand_op = true;
- CHECK_EQ(post_stb_op->inputs.size(), 2);
- CHECK_EQ(post_stb_op->outputs.size(), 1);
- }
-
- // Conv Op
- const string& input_of_conv_op =
- has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
- auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
- if (conv_base_op->type != OperatorType::kConv) {
- return false;
- }
- auto* conv_op = static_cast<ConvOperator*>(conv_base_op);
+template <typename T>
+bool ResolveDilatedConv(Model* model, Operator* conv_base_op, Operator* stb_op,
+ Operator* post_stb_op, bool has_expand_op,
+ int dilation_factor) {
+ auto* conv_op = static_cast<T*>(conv_base_op);
if (conv_op->inputs.size() != 2) {
// The conv op must only have weights, no bias.
return false;
@@ -158,8 +119,6 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
CHECK_EQ(bias_add_op->inputs.size(), 2);
CHECK_EQ(bias_add_op->outputs.size(), 1);
- LOG(INFO) << "Identified sub-network emulating dilated convolution.";
-
// 2. RE-WIRE OPERATORS
// ***************************************************************************
// Re-use the existing Conv2D op.
@@ -206,9 +165,71 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
DeleteArrayIfUnused(stb_op_inputs[1], model);
DeleteArrayIfUnused(stb_op_inputs[2], model);
- LOG(INFO) << "Replaced with Dilated Conv2D op outputting \""
- << conv_op->outputs[0] << "\".";
return true;
}
+bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* stb_op = it->get();
+
+ // 1. IDENTIFY OPERATORS
+ // ***************************************************************************
+ // SpaceToBatch Op.
+ if (stb_op->type != OperatorType::kSpaceToBatchND) {
+ return false;
+ }
+ if (stb_op->inputs.size() != 3) {
+ return false;
+ }
+ CHECK_EQ(stb_op->outputs.size(), 1);
+ // Extract the dilation factor from Input[1] of SpaceToBatch
+ // TODO(mjmatthews): Support 2D dilation factors.
+ const auto& block_shape_array = model->GetArray(stb_op->inputs[1]);
+ if (!block_shape_array.buffer) {
+ return false;
+ }
+ CHECK_EQ(block_shape_array.shape().dimensions_count(), 1);
+ int dilation_factor =
+ block_shape_array.Array::GetBuffer<ArrayDataType::kInt32>().data[0];
+
+ // Expand Op
+ auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]);
+ if (!post_stb_op) {
+ return false;
+ }
+ bool has_expand_op = false;
+ if (post_stb_op->type == OperatorType::kExpandDims) {
+ has_expand_op = true;
+ CHECK_EQ(post_stb_op->inputs.size(), 2);
+ CHECK_EQ(post_stb_op->outputs.size(), 1);
+ }
+
+ // Conv Op
+ const string& input_of_conv_op =
+ has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0];
+ auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op);
+ bool changed = false;
+ if (conv_base_op->type == OperatorType::kConv) {
+ changed = ResolveDilatedConv<ConvOperator>(model, conv_base_op, stb_op,
+ post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO) << "Replaced sub-network with Dilated Conv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ } else if (identify_depthwise_conv_ &&
+ conv_base_op->type == OperatorType::kDepthwiseConv) {
+ changed = ResolveDilatedConv<DepthwiseConvOperator>(
+ model, conv_base_op, stb_op, post_stb_op, has_expand_op,
+ dilation_factor);
+ if (changed) {
+ LOG(INFO)
+ << "Replaced sub-netork with Dilated DepthwiseConv2D op outputting \""
+ << conv_base_op->outputs[0] << "\".";
+ }
+ }
+
+ return changed;
+}
+
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f103bb94ae..d056a8add7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
@@ -658,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
}
}
auto& output_array = model->GetArray(op->outputs[0]);
- // Use 0 input as basis for output dimensions.
- const auto& first_input_array = model->GetArray(op->inputs[0]);
- output_array.copy_shape(first_input_array.shape());
- // Negative axis means the count starts at the back of the dims().
- if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
+ // Use first non-empty input as basis for output dimensions.
+ for (const auto& input_name : op->inputs) {
+ const auto& input_array = model->GetArray(input_name);
+ if (input_array.shape().dimensions_count() > 0) {
+ output_array.copy_shape(input_array.shape());
+ // Negative axis means the count starts at the back of the dims().
+ if (op->axis < 0) op->axis += input_array.shape().dims().size();
+ break;
+ }
+ }
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -1655,6 +1661,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
case OperatorType::kLogicalOr:
+ case OperatorType::kZerosLike:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather:
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
index 8266e2c205..8e150db6fa 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -25,29 +25,57 @@ limitations under the License.
namespace toco {
+namespace {
+
+void RenameArray(Model* model, const string& oldname,
+ const string& desired_newname) {
+ const string& newname = AvailableArrayName(*model, desired_newname);
+ auto& arrays = model->GetMutableArrayMap();
+ arrays[newname] = std::move(arrays[oldname]);
+ arrays.erase(oldname);
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == oldname) {
+ input = newname;
+ }
+ }
+ for (string& output : op->outputs) {
+ if (output == oldname) {
+ output = newname;
+ }
+ }
+ }
+}
+
+} // namespace
+
// Reorder the elements of an input_array according to the input_axes_order and
// output_axes_order. Then adjust the shapes of the input and output arrays
// accordingly. Note that input_array must have a buffer (that is, it is a
// constant array).
template <typename T, ArrayDataType DataType>
void ReorderAxes(AxesOrder input_axes_order, AxesOrder output_axes_order,
- Array* input_array, Array* output_array) {
- CHECK(input_array->buffer->type == DataType);
- CHECK(!output_array->buffer);
- auto& input_data = input_array->GetMutableBuffer<DataType>().data;
- std::vector<T> reordered_data;
- reordered_data.resize(RequiredBufferSizeForShape(output_array->shape()));
+ const Array& input_array, Array* output_array) {
+ DCHECK(input_array.buffer->type == DataType);
+ DCHECK(!output_array->buffer);
+ const auto& input_data = input_array.GetBuffer<DataType>().data;
+ auto& output_data = output_array->GetMutableBuffer<DataType>().data;
+ output_data.resize(RequiredBufferSizeForShape(output_array->shape()));
// TODO(b/62904716) Shapes should be used directly.
- Shape input_shape = input_array->shape();
+ Shape input_shape = input_array.shape();
Shape output_shape = output_array->shape();
if (AxesCount(input_axes_order) == 2) {
UnextendShape(&input_shape, 2);
UnextendShape(&output_shape, 2);
}
ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
- input_data.data(), reordered_data.data());
- input_data = reordered_data;
- input_array->copy_shape(output_array->shape());
+ input_data.data(), output_data.data());
+ if (input_array.minmax) {
+ output_array->GetOrCreateMinMax() = input_array.GetMinMax();
+ }
+ if (input_array.narrow_range) {
+ output_array->narrow_range = true;
+ }
}
bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
@@ -57,8 +85,11 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
return false;
}
auto* reorder_op = static_cast<ReorderAxesOperator*>(op);
- const auto& input_array_name = reorder_op->inputs[0];
- const auto& output_array_name = reorder_op->outputs[0];
+
+ // Intentionally copies, not references.
+ const string input_array_name = reorder_op->inputs[0];
+ const string output_array_name = reorder_op->outputs[0];
+
auto& input_array = model->GetArray(input_array_name);
auto& output_array = model->GetArray(output_array_name);
if (!input_array.buffer) {
@@ -72,31 +103,23 @@ bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
if (input_array.buffer->type == ArrayDataType::kFloat) {
ReorderAxes<float, ArrayDataType::kFloat>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
- } else if (input_array.buffer->type == ArrayDataType::kInt32) {
+ input_array, &output_array);
+ } else if (input_array.buffer->type == ArrayDataType::kUint8) {
+ // TODO(benoitjacob): This path seems unused.
+ // ReorderAxes is only used when importing from
+ // TensorFlow GraphDef, which does not support quantized nodes.
ReorderAxes<uint8, ArrayDataType::kUint8>(reorder_op->input_axes_order,
reorder_op->output_axes_order,
- &input_array, &output_array);
+ input_array, &output_array);
} else {
LOG(FATAL) << "Cannot ReorderAxes unless input buffer is float or uint8.";
}
- input_array.copy_shape(output_array.shape());
-
- // Update the edges of the graph to point to the input array
- for (const auto& other_op : model->operators) {
- for (auto& input : other_op->inputs) {
- if (input == output_array_name) {
- input = input_array_name;
- }
- }
- }
-
AddMessageF("Reordered axes for array %s", input_array_name);
- // Remove the op and output array.
- model->EraseArray(output_array_name);
- model->operators.erase(it);
+ DeleteOpAndArraysIfUnused(model, op);
+ RenameArray(model, output_array_name, input_array_name);
+
return true;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
index fcf30bd347..65346c4fe4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -24,6 +24,37 @@ limitations under the License.
namespace toco {
+namespace {
+
+TransposeOperator* FindTransposeOpWithInput(const Model& model,
+ const string& array_name) {
+ for (auto it = model.operators.begin(); it != model.operators.end(); ++it) {
+ Operator* op = it->get();
+ if (op->type != OperatorType::kTranspose) {
+ continue;
+ }
+ if (op->inputs[0] != array_name) {
+ continue;
+ }
+ const auto& permutation_array = model.GetArray(op->inputs[1]);
+ if (permutation_array.data_type != ArrayDataType::kInt32) {
+ continue;
+ }
+ const auto& permutation_data =
+ permutation_array.GetBuffer<ArrayDataType::kInt32>().data;
+ if (permutation_data.size() != 2) {
+ continue;
+ }
+ if (permutation_data[0] != 1 || permutation_data[1] != 0) {
+ continue;
+ }
+ return static_cast<TransposeOperator*>(op);
+ }
+ return nullptr;
+}
+
+} // namespace
+
bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
auto matmul_it = model->operators.begin() + op_index;
if (matmul_it->get()->type != OperatorType::kMatMul) {
@@ -37,7 +68,13 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
// TransposeOperator. However, the second input is supposed to be 2D, so we
// can actually handle transposition of that matrix, which happens to be more
// common anyway.
- CHECK(!matmul_op->transpose_a);
+ if (matmul_op->transpose_a) {
+ AddMessageF(
+ "Not replacing %s by a FullyConnected operator, because it has "
+ "the transpose_a attribute",
+ LogName(*matmul_op));
+ return false;
+ }
// Reorder the axes on the second input. TensorFlow uses row-major ordering
// on both inputs, however this is inefficient for the FullyConnected
@@ -46,18 +83,35 @@ bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
string input_lhs = matmul_op->inputs[0];
string input_rhs = matmul_op->inputs[1];
if (!matmul_op->transpose_b) {
- auto* transpose_op = new TransposeOperator;
- transpose_op->inputs = {
- matmul_op->inputs[1],
- CreateInt32Array(model,
- AvailableArrayName(
- *model, matmul_op->inputs[1] + "/transpose/perm"),
- {1, 0})};
- transpose_op->outputs = {
- AvailableArrayName(*model, matmul_op->inputs[1] + "/transpose")};
- model->GetOrCreateArray(transpose_op->outputs[0]);
- model->operators.emplace(matmul_it, transpose_op);
-
+ // Need to transpose input_rhs, by inserting a TransposeOperator.
+ // First, check if there already is a TransposeOperator transposing that
+ // array, so we can just reuse it.
+ auto* transpose_op = FindTransposeOpWithInput(*model, input_rhs);
+ if (!transpose_op) {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, created new "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ // No such TransposeOperator found. Create one now.
+ transpose_op = new TransposeOperator;
+ transpose_op->inputs = {
+ input_rhs,
+ CreateInt32Array(
+ model, AvailableArrayName(*model, input_rhs + "/transpose/perm"),
+ {1, 0})};
+ transpose_op->outputs = {
+ AvailableArrayName(*model, input_rhs + "/transpose")};
+ model->GetOrCreateArray(transpose_op->outputs[0]);
+ model->operators.emplace(matmul_it, transpose_op);
+ // Sanity check
+ DCHECK_EQ(transpose_op, FindTransposeOpWithInput(*model, input_rhs));
+ } else {
+ AddMessageF(
+ "While replacing %s by a FullyConnected operator, reused existing "
+ "Transpose op wrapping RHS input array %s",
+ LogName(*matmul_op), input_rhs);
+ }
+ // Re-wire: have the matmul consume the transposed array.
input_rhs = transpose_op->outputs[0];
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9bc23c4b3c..e02d000e7e 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -58,6 +58,7 @@ using tensorflow::DT_STRING;
using tensorflow::DT_UINT8;
using tensorflow::GraphDef;
using tensorflow::NodeDef;
+using tensorflow::OpRegistry;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
@@ -68,6 +69,13 @@ bool HasAttr(const NodeDef& node, const string& attr_name) {
return node.attr().count(attr_name) > 0;
}
+bool HasWildcardDimension(const TensorShapeProto& shape) {
+ for (const auto& dim : shape.dim()) {
+ if (dim.size() == -1) return true;
+ }
+ return false;
+}
+
const string& GetStringAttr(const NodeDef& node, const string& attr_name) {
CHECK(HasAttr(node, attr_name));
const auto& attr = node.attr().at(attr_name);
@@ -633,6 +641,23 @@ tensorflow::Status ConvertDepthwiseConvOperator(
CHECK_EQ(strides.i(3), 1);
conv->stride_height = strides.i(1);
conv->stride_width = strides.i(2);
+ if (HasAttr(node, "dilations")) {
+ const auto& dilations = GetListAttr(node, "dilations");
+ TF_RETURN_IF_ERROR(
+ ExpectValue(dilations.i_size(), 4, "number of dilations"));
+ if (dilations.i(0) != 1 || dilations.i(3) != 1) {
+ return tensorflow::errors::InvalidArgument(absl::StrCat(
+ "Can only import Conv ops with dilation along the height "
+ "(1st) or width (2nd) axis. TensorFlow op \"",
+ node.name(), "\" had dilations:[ ", dilations.i(0), ", ",
+ dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "]."));
+ }
+ conv->dilation_height_factor = dilations.i(1);
+ conv->dilation_width_factor = dilations.i(2);
+ } else {
+ conv->dilation_height_factor = 1;
+ conv->dilation_width_factor = 1;
+ }
const auto& padding = GetStringAttr(node, "padding");
if (padding == "SAME") {
conv->padding.type = PaddingType::kSame;
@@ -1053,15 +1078,27 @@ tensorflow::Status ConvertUnsupportedOperator(
"_support_output_type_float_in_quantized_op";
LOG(INFO) << "Converting unsupported operation: " << node.op();
+
auto* op = new TensorFlowUnsupportedOperator;
+ op->tensorflow_op = node.op();
+ node.SerializeToString(&op->tensorflow_node_def);
+ model->operators.emplace_back(op);
+
+ // Parse inputs.
const int num_inputs = GetInputsCount(node, tf_import_flags);
for (int i = 0; i < num_inputs; ++i) {
op->inputs.push_back(node.input(i));
}
- op->outputs.push_back(node.name());
- op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
- model->operators.emplace_back(op);
+
+ // Parse outputs.
+ op->outputs.push_back(node.name()); // Implicit :0.
+ const tensorflow::OpDef* op_def = nullptr;
+ if (tensorflow::OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ for (int i = 1; i < op_def->output_arg_size(); ++i) {
+ op->outputs.push_back(absl::StrCat(node.name(), ":", i));
+ }
+ }
+
// Parse if the op supports quantization
if (HasAttr(node, kAttrOutputQuantized)) {
op->quantized = GetBoolAttr(node, kAttrOutputQuantized);
@@ -1071,6 +1108,8 @@ tensorflow::Status ConvertUnsupportedOperator(
op->support_output_type_float_in_quantized_op =
GetBoolAttr(node, kAttrSupportOutputTypeFloatInQuantizedOp);
}
+
+ // Parse output type(s).
if (HasAttr(node, kAttrOutputTypes)) {
const auto& output_types = GetListAttr(node, kAttrOutputTypes);
for (int i = 0; i < output_types.type_size(); ++i) {
@@ -1079,14 +1118,40 @@ tensorflow::Status ConvertUnsupportedOperator(
} else if (HasAttr(node, "Tout")) {
const auto& output_type = GetDataTypeAttr(node, "Tout");
op->output_data_types.push_back(ConvertDataType(output_type));
+ } else if (op_def != nullptr) {
+ for (const auto& output_arg : op_def->output_arg()) {
+ if (HasAttr(node, output_arg.type_attr())) {
+ op->output_data_types.push_back(
+ ConvertDataType(GetDataTypeAttr(node, output_arg.type_attr())));
+ } else {
+ LOG(INFO) << "Op node missing output type attribute: " << node.name();
+ op->output_data_types.clear();
+ break;
+ }
+ }
+ } else {
+ // TODO(b/113613439): Figure out how to propagate types for custom ops
+ // that have no OpDef.
+ LOG(INFO) << "Unable to determine output type for op: " << node.op();
}
+
+ // Parse output shape(s).
if (HasAttr(node, kAttrOutputShapes)) {
const auto& output_shapes = GetListAttr(node, kAttrOutputShapes);
Shape output_shape;
for (int i = 0; i < output_shapes.shape_size(); ++i) {
+ const auto& shape = output_shapes.shape(i);
+ // TOCO doesn't yet properly handle shapes with wildcard dimensions.
+ // TODO(b/113613439): Handle shape inference for unsupported ops that have
+ // shapes with wildcard dimensions.
+ if (HasWildcardDimension(shape)) {
+ LOG(INFO) << "Skipping wildcard output shape(s) for node: "
+ << node.name();
+ op->output_shapes.clear();
+ break;
+ }
const auto status =
- ImportShape(output_shapes.shape(i).dim(), /*input_flat_size=*/nullptr,
- &output_shape);
+ ImportShape(shape.dim(), /*input_flat_size=*/nullptr, &output_shape);
if (!status.ok()) {
return status;
}
@@ -1139,15 +1204,9 @@ tensorflow::Status ConvertPlaceholderOperator(
if (node.attr().count("shape")) {
const auto& shape = GetShapeAttr(node, "shape");
auto num_dims = shape.dim_size();
- bool has_wildcard = false;
- for (std::size_t i = 0; i < num_dims; i++) {
- if (shape.dim(i).size() == -1) {
- has_wildcard = true;
- }
- }
// TODO(b/62716978): This logic needs to be revisted. During dims
// refactoring it is an interim fix.
- if (num_dims > 0 && !has_wildcard) {
+ if (num_dims > 0 && !HasWildcardDimension(shape)) {
auto& dst_array_dims = *array.mutable_shape()->mutable_dims();
dst_array_dims.resize(num_dims);
for (std::size_t i = 0; i < num_dims; i++) {
@@ -2023,6 +2082,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"TopKV2", ConvertTopKV2Operator},
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
{"Unpack", ConvertUnpackOperator},
+ {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
});
}
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
index 90e6f698ef..8a236d4444 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -48,6 +49,39 @@ Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&,
namespace {
+Status ImportNode(const NodeDef& node, Model* model) {
+ const auto converter = internal::GetTensorFlowNodeConverterMap();
+ return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), model,
+ converter);
+}
+
+Status ImportNode(const NodeDef& node) {
+ Model model;
+ return ImportNode(node, &model);
+}
+
+NodeDef BuildNode(
+ const std::string& op,
+ const std::vector<std::initializer_list<int>>& output_shapes) {
+ NodeDef node;
+ node.set_op(op);
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ AttrValue::ListValue* shapes =
+ (*node.mutable_attr())["_output_shapes"].mutable_list();
+ for (const auto& output_shape : output_shapes) {
+ tensorflow::TensorShapeProto* shape = shapes->add_shape();
+ for (int64_t output_shape_dim : output_shape) {
+ auto shape_dim = shape->add_dim();
+ shape_dim->set_size(output_shape_dim);
+ }
+ }
+
+ return node;
+}
+
class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
protected:
ShapeImportTest() {}
@@ -108,12 +142,24 @@ class ShapeImportTest : public ::testing::TestWithParam<tensorflow::DataType> {
SetAttrValue(t, &value_attr);
(*node->mutable_attr())["value"] = value_attr;
}
+};
- Status ImportNode(const NodeDef& node) {
- Model model;
- const auto converter = internal::GetTensorFlowNodeConverterMap();
- return internal::ImportTensorFlowNode(node, TensorFlowImportFlags(), &model,
- converter);
+class TypeImportTest : public ::testing::TestWithParam<
+ std::pair<tensorflow::DataType, ArrayDataType>> {
+ protected:
+ TypeImportTest() {}
+
+ void BuildUnaryNode(const std::string& op_name, tensorflow::DataType dtype,
+ NodeDef* node) {
+ node->set_op(op_name);
+ node->set_name("Node1");
+
+ node->add_input();
+ node->set_input(0, "Node0");
+
+ AttrValue dtype_attr;
+ SetAttrValue(dtype, &dtype_attr);
+ (*node->mutable_attr())["T"] = dtype_attr;
}
};
@@ -166,5 +212,77 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) {
INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest,
::testing::ValuesIn(TestTypes()));
+std::vector<std::pair<tensorflow::DataType, ArrayDataType>> UnaryTestTypes() {
+ return {{DT_FLOAT, ArrayDataType::kFloat},
+ {DT_INT32, ArrayDataType::kInt32},
+ {DT_INT64, ArrayDataType::kInt64}};
+}
+
+TEST_P(TypeImportTest, BasicTypeInference) {
+ NodeDef node;
+ BuildUnaryNode("Atan", GetParam().first, &node);
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_THAT(op->output_data_types, ::testing::ElementsAre(GetParam().second));
+}
+INSTANTIATE_TEST_CASE_P(BasicTypeInference, TypeImportTest,
+ ::testing::ValuesIn(UnaryTestTypes()));
+
+TEST(ImportTest, FailedTypeInference) {
+ // Create a unary op with no Type ("T") annotation.
+ NodeDef node;
+ node.set_op("Atan");
+ node.set_name("Node1");
+ node.add_input();
+ node.set_input(0, "Node0");
+
+ Model model;
+ EXPECT_TRUE(ImportNode(node, &model).ok());
+
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+ ASSERT_TRUE(op->output_data_types.empty());
+}
+
+TEST(ImportTest, UnsupportedOpWithOutputShapes) {
+ // Create an unsupported op with output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{1, 2}, {2, 3}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // The output shapes should be imported.
+ ASSERT_EQ(op->output_shapes.size(), 2);
+ ASSERT_THAT(op->output_shapes[0].dims(), ::testing::ElementsAre(1, 2));
+ ASSERT_THAT(op->output_shapes[1].dims(), ::testing::ElementsAre(2, 3));
+}
+
+TEST(ImportTest, UnsupportedOpWithWildcardOutputShapes) {
+ // Create an unsupported op with wildcard output shapes.
+ Model model;
+ EXPECT_TRUE(ImportNode(BuildNode("Atan", {{-1, 2}}), &model).ok());
+ ASSERT_THAT(model.operators.size(), ::testing::Ge(1));
+ ASSERT_EQ(model.operators[0]->type, OperatorType::kUnsupported);
+ const TensorFlowUnsupportedOperator* op =
+ static_cast<const TensorFlowUnsupportedOperator*>(
+ model.operators[0].get());
+
+ // Wildcard shapes aren't yet supported.
+ ASSERT_TRUE(op->output_shapes.empty());
+}
+
} // namespace
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 2e100e37f6..6e207fdf54 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -150,6 +150,7 @@ enum class OperatorType : uint8 {
kLogicalOr,
kCTCBeamSearchDecoder,
kUnpack,
+ kZerosLike,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -477,6 +478,11 @@ struct DepthwiseConvOperator : Operator {
int stride_height = 0;
int stride_width = 0;
int depth_multiplier = 0;
+ // A dilation_rate of 0 is invalid and this field is an optional attribute.
+ // Thus initializing it to 1 to allow default conv behavior when the
+ // attribute is not present.
+ int dilation_width_factor = 1;
+ int dilation_height_factor = 1;
};
// Depth-to-space transform operator.
@@ -1844,6 +1850,16 @@ struct UnpackOperator : Operator {
ArrayDataType dtype = ArrayDataType::kNone;
};
+// ZerosLike operator:
+//
+// Inputs:
+// inputs[0]: required: the input array
+//
+// TensorFlow equivalent: tf.zeros_like
+struct TensorFlowZerosLikeOperator : Operator {
+ TensorFlowZerosLikeOperator() : Operator(OperatorType::kZerosLike) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
@@ -2068,6 +2084,7 @@ class Model {
}
}
const ArrayMap& GetArrayMap() const { return arrays; }
+ ArrayMap& GetMutableArrayMap() { return arrays; }
int64 ArithmeticOpsCount() const { return ops_count; }
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index eb0f7c443a..ca2a6a19b3 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -107,7 +107,8 @@ class DepthwiseConvolution
ActivationFunction::Serialize(op.fused_activation_function);
return ::tflite::CreateDepthwiseConv2DOptions(
*builder, padding, op.stride_width, op.stride_height,
- op.depth_multiplier, activation_function);
+ op.depth_multiplier, activation_function, op.dilation_width_factor,
+ op.dilation_height_factor);
}
void ReadOptions(const TfLiteOptions& options,
@@ -118,9 +119,18 @@ class DepthwiseConvolution
op->depth_multiplier = options.depth_multiplier();
op->fused_activation_function =
ActivationFunction::Deserialize(options.fused_activation_function());
+ op->dilation_width_factor = options.dilation_w_factor();
+ op->dilation_height_factor = options.dilation_h_factor();
}
- int GetVersion(const Operator& op) const override { return 1; }
+ int GetVersion(const Operator& op) const override {
+ const auto& conv_op = static_cast<const DepthwiseConvOperator&>(op);
+ if (conv_op.dilation_width_factor != 1 ||
+ conv_op.dilation_height_factor != 1) {
+ return 2;
+ }
+ return 1;
+ }
};
class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
@@ -1250,6 +1260,10 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>(fbb.release());
}
+// TODO(wvo): hack to make this code compile with 2 different API versions.
+// Please remove once OS/internal versions are in sync.
+// See hardcoded values in the switch below.
+
void ReadOptions(const flexbuffers::Map& m,
TensorFlowUnsupportedOperator* op) const {
::tensorflow::NodeDef node_def;
@@ -1260,16 +1274,16 @@ class TensorFlowUnsupported : public BaseOperator {
const auto key = keys[i].AsKey();
const auto& value = m[key];
switch (value.GetType()) {
- case flexbuffers::TYPE_STRING:
+ case 5: // flexbuffers::FBT_STRING:
(*attr)[key].set_s(value.AsString().c_str());
break;
- case flexbuffers::TYPE_INT:
+ case 1: // flexbuffers::FBT_INT:
(*attr)[key].set_i(value.AsInt64());
break;
- case flexbuffers::TYPE_FLOAT:
+ case 3: // flexbuffers::FBT_FLOAT:
(*attr)[key].set_f(value.AsFloat());
break;
- case flexbuffers::TYPE_BOOL:
+ case 26: // flexbuffers::FBT_BOOL:
(*attr)[key].set_b(value.AsBool());
if (string(key) == "_output_quantized") {
op->quantized = value.AsBool();
@@ -1278,7 +1292,7 @@ class TensorFlowUnsupported : public BaseOperator {
op->support_output_type_float_in_quantized_op = value.AsBool();
}
break;
- case flexbuffers::TYPE_VECTOR_INT: {
+ case 11: { // flexbuffers::FBT_VECTOR_INT: {
auto* list = (*attr)[key].mutable_list();
const auto& vector = value.AsTypedVector();
for (size_t i = 0; i < vector.size(); i++) {
@@ -1488,6 +1502,10 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
"SQRT", OperatorType::kSqrt));
ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
"RSQRT", OperatorType::kRsqrt));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
+ "SQUARE", OperatorType::kSquare));
+ ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
+ "ZEROS_LIKE", OperatorType::kZerosLike));
return ops;
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 519a3a4e01..0bc591e647 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -144,6 +144,10 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
+ CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
+ OperatorType::kSquare);
+ CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE",
+ OperatorType::kZerosLike);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index a7c17156b1..a08b02485f 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -101,7 +101,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveTensorFlowSwitch);
transformations->Add(new ResolveTensorFlowConcat);
transformations->Add(new ResolveMultiplyByZero);
- transformations->Add(new IdentifyDilatedConv);
transformations->Add(new IdentifyL2Normalization);
transformations->Add(new IdentifyL2Pool);
transformations->Add(new IdentifyRelu1);
@@ -282,6 +281,14 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
}
transformations.Add(new ResolveConstantConcatenation);
+ // TODO(b/116063589): TF GraphDef doesn't support dilations on its depthwise
+ // conv, so we need to make sure we don't convert to dilated depthwise conv
+ // when outputing to TF GraphDef.
+ auto* identify_dilated_conv = new IdentifyDilatedConv;
+ if (output_format == TENSORFLOW_GRAPHDEF) {
+ identify_dilated_conv->set_identify_depthwise_conv(false);
+ }
+ transformations.Add(identify_dilated_conv);
RunGraphTransformations(model, "general graph transformations",
transformations);
@@ -367,9 +374,7 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
}
// Deduplicate large constant arrays.
- if (toco_flags.has_dedupe_array_min_size_bytes()) {
- DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
- }
+ DedupeConstantArrays(model, toco_flags.dedupe_array_min_size_bytes());
LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 6ab93d9316..4a1ae35cb5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -406,6 +406,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
HANDLE_OPERATORTYPENAME_CASE(Unpack)
+ HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 02039922b4..ef4f0fa80d 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -232,6 +232,46 @@ uint64_t BenchmarkTfLiteModel::ComputeInputBytes() {
return total_input_bytes;
}
+void BenchmarkTfLiteModel::PrepareInputsAndOutputs() {
+ auto interpreter_inputs = interpreter->inputs();
+ // Set the values of the input tensors.
+ for (int j = 0; j < inputs.size(); ++j) {
+ const InputLayerInfo& input = inputs[j];
+ int i = interpreter_inputs[j];
+ TfLiteTensor* t = interpreter->tensor(i);
+ std::vector<int> sizes = input.shape;
+
+ // TODO(ahentz): below we ignore the O-th dimension (number of batches).
+ if (t->type == kTfLiteFloat32) {
+ FillRandomValue<float>(
+ interpreter->typed_tensor<float>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
+ } else if (t->type == kTfLiteInt32) {
+ // TODO(yunluli): This is currently only used for handling embedding input
+ // for speech models. Generalize if necessary.
+ FillRandomValue<int32_t>(
+ interpreter->typed_tensor<int32_t>(i),
+ std::vector<int32_t>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<int32_t>(rand()) % 100; });
+ } else if (t->type == kTfLiteUInt8) {
+ FillRandomValue<uint8_t>(
+ interpreter->typed_tensor<uint8_t>(i),
+ std::vector<int>(sizes.begin() + 1, sizes.end()),
+ []() { return static_cast<uint8_t>(rand()) % 255; });
+ } else if (t->type == kTfLiteString) {
+ tflite::DynamicBuffer buffer;
+ FillRandomString(&buffer, sizes, []() {
+ return "we're have some friends over saturday to hang out in the yard";
+ });
+ buffer.WriteToTensor(interpreter->tensor(i));
+ } else {
+ TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
+ << " of type " << t->type;
+ }
+ }
+}
+
void BenchmarkTfLiteModel::Init() {
std::string graph = params_.Get<std::string>("graph");
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
@@ -305,36 +345,6 @@ void BenchmarkTfLiteModel::Init() {
if (interpreter->AllocateTensors() != kTfLiteOk) {
TFLITE_LOG(FATAL) << "Failed to allocate tensors!";
}
-
- // Set the values of the input tensors.
- for (int j = 0; j < inputs.size(); ++j) {
- const InputLayerInfo& input = inputs[j];
- int i = interpreter_inputs[j];
- TfLiteTensor* t = interpreter->tensor(i);
- std::vector<int> sizes = input.shape;
-
- // TODO(ahentz): below we ignore the O-th dimension (number of batches).
- if (t->type == kTfLiteFloat32) {
- FillRandomValue<float>(
- interpreter->typed_tensor<float>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<float>(rand()) / RAND_MAX - 0.5f; });
- } else if (t->type == kTfLiteUInt8) {
- FillRandomValue<uint8_t>(
- interpreter->typed_tensor<uint8_t>(i),
- std::vector<int>(sizes.begin() + 1, sizes.end()),
- []() { return static_cast<uint8_t>(rand()) % 255; });
- } else if (t->type == kTfLiteString) {
- tflite::DynamicBuffer buffer;
- FillRandomString(&buffer, sizes, []() {
- return "we're have some friends over saturday to hang out in the yard";
- });
- buffer.WriteToTensor(interpreter->tensor(i));
- } else {
- TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
- << " of type " << t->type;
- }
- }
}
void BenchmarkTfLiteModel::RunImpl() {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index 4c4320a998..8541512bc8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -69,6 +69,9 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
std::vector<int> shape;
};
+ protected:
+ void PrepareInputsAndOutputs() override;
+
private:
#ifdef TFLITE_EXTENDED
std::unique_ptr<EagerDelegate> delegate_;
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index 59bdb10811..16012a3fb1 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -30,6 +30,7 @@ INCLUDES := \
-I$(MAKEFILE_DIR)/../../../../../../ \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/absl \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
-I$(MAKEFILE_DIR)/downloads/farmhash/src \
diff --git a/tensorflow/contrib/lite/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
index 29afa45133..3570f9a38d 100755
--- a/tensorflow/contrib/lite/tools/make/download_dependencies.sh
+++ b/tensorflow/contrib/lite/tools/make/download_dependencies.sh
@@ -35,7 +35,7 @@ GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.g
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
-FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/v1.8.0.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
index b863108aa4..d02d78bf53 100644
--- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -206,6 +206,14 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
continue;
}
+ // Some tensors may have a null buffer vector, indicating an intermediate
+ // array.
+ if (model->buffers[tensor->buffer]->data.data() == nullptr) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " because it has no allocated buffer.";
+ continue;
+ }
+
TensorInfo tensor_info;
tensor_info.eval_hybrid = eval_hybrid;
tensor_info.op_input_idx = op_input_idx;
diff --git a/tensorflow/contrib/lite/tools/visualize.py b/tensorflow/contrib/lite/tools/visualize.py
index 597dede63b..d7eea79399 100644
--- a/tensorflow/contrib/lite/tools/visualize.py
+++ b/tensorflow/contrib/lite/tools/visualize.py
@@ -202,7 +202,7 @@ class TensorMapper(object):
html += str(i) + " "
html += tensor["name"] + " "
html += str(tensor["type"]) + " "
- html += repr(tensor["shape"]) + "<br>"
+ html += (repr(tensor["shape"]) if "shape" in tensor else "[]") + "<br>"
html += "</span>"
html += repr(x)
html += "</span>"
diff --git a/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
new file mode 100644
index 0000000000..80cdb2f080
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb
@@ -0,0 +1,703 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "6Y8E0lw5eYWm"
+ },
+ "source": [
+ "# Post Training Quantization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "CIGrZZPTZVeO"
+ },
+ "source": [
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ " \u003ctd\u003e\n",
+ " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/tutorials/post_training_quant.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
+ " \u003c/td\u003e\n",
+ "\u003c/table\u003e"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "BTC1rDAuei_1"
+ },
+ "source": [
+ "## Overview\n",
+ "\n",
+ "[TensorFlow Lite](https://www.tensorflow.org/lite/) now supports\n",
+ "converting weights to 8 bit precision as part of model conversion from\n",
+ "tensorflow graphdefs to TFLite's flat buffer format. Weight quantization\n",
+ "achieves a 4x reduction in the model size. In addition, TFLite supports on the\n",
+ "fly quantization and dequantization of activations to allow for:\n",
+ "\n",
+ "1. Using quantized kernels for faster implementation when available.\n",
+ "\n",
+ "2. Mixing of floating-point kernels with quantized kernels for different parts\n",
+ " of the graph.\n",
+ "\n",
+ "Note that the activations are always stored in floating point. For ops that\n",
+ "support quantized kernels, the activations are quantized to 8 bits of precision\n",
+ "dynamically prior to processing and are de-quantized to float precision after\n",
+ "processing. Depending on the model being converted, this can give a speedup over\n",
+ "pure floating point computation.\n",
+ "\n",
+ "In contrast to\n",
+ "[quantization aware training](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize)\n",
+ ", the weights are quantized post training and the activations are quantized dynamically \n",
+ "at inference in this method.\n",
+ "Therefore, the model weights are not retrained to compensate for quantization\n",
+ "induced errors. It is important to check the accuracy of the quantized model to\n",
+ "ensure that the degradation is acceptable.\n",
+ "\n",
+ "In this tutorial, we train an MNIST model from scratch, check its accuracy in\n",
+ "tensorflow and then convert the saved model into a Tensorflow Lite flatbuffer\n",
+ "with weight quantization. We finally check the\n",
+ "accuracy of the converted model and compare it to the original saved model. We\n",
+ "run the training script mnist.py from\n",
+ "[Tensorflow official mnist tutorial](https://github.com/tensorflow/models/tree/master/official/mnist).\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2XsEP17Zelz9"
+ },
+ "source": [
+ "## Building an MNIST model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "dDqqUIZjZjac"
+ },
+ "source": [
+ "### Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "gyqAw1M9lyab"
+ },
+ "outputs": [],
+ "source": [
+ "! pip uninstall -y tensorflow\n",
+ "! pip install -U tf-nightly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "WsN6s5L1ieNl"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "00U0taBoe-w7"
+ },
+ "outputs": [],
+ "source": [
+ "! git clone --depth 1 https://github.com/tensorflow/models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "4XZPtSh-fUOc"
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "import os\n",
+ "\n",
+ "if sys.version_info.major \u003e= 3:\n",
+ " import pathlib\n",
+ "else:\n",
+ " import pathlib2 as pathlib\n",
+ "\n",
+ "# Add `models` to the python path.\n",
+ "models_path = os.path.join(os.getcwd(), \"models\")\n",
+ "sys.path.append(models_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "eQ6Q0qqKZogR"
+ },
+ "source": [
+ "### Train and export the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eMsw_6HujaqM"
+ },
+ "outputs": [],
+ "source": [
+ "saved_models_root = \"/tmp/mnist_saved_model\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "hWSAjQWagIHl"
+ },
+ "outputs": [],
+ "source": [
+ "# The above path addition is not visible to subprocesses, add the path for the subprocess as well.\n",
+ "# Note: channels_last is required here or the conversion may fail. \n",
+ "!PYTHONPATH={models_path} python models/official/mnist/mnist.py --train_epochs=1 --export_dir {saved_models_root} --data_format=channels_last"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "5NMaNZQCkW9X"
+ },
+ "source": [
+ "For the example, we only trained the model for a single epoch, so it only trains to ~96% accuracy.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "xl8_fzVAZwOh"
+ },
+ "source": [
+ "### Convert to a TFLite model\n",
+ "\n",
+ "The `savedmodel` directory is named with a timestamp. Select the most recent one: "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Xp5oClaZkbtn"
+ },
+ "outputs": [],
+ "source": [
+ "saved_model_dir = str(sorted(pathlib.Path(saved_models_root).glob(\"*\"))[-1])\n",
+ "saved_model_dir"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "AT8BgkKmljOy"
+ },
+ "source": [
+ "Using the python `TocoConverter`, the saved model can be converted into a TFLite model.\n",
+ "\n",
+ "First load the model using the `TocoConverter`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "_i8B2nDZmAgQ"
+ },
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "tf.enable_eager_execution()\n",
+ "converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)\n",
+ "tflite_model = converter.convert()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "F2o2ZfF0aiCx"
+ },
+ "source": [
+ "Write it out to a tflite file:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vptWZq2xnclo"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_models_dir = pathlib.Path(\"/tmp/mnist_tflite_models/\")\n",
+ "tflite_models_dir.mkdir(exist_ok=True, parents=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Ie9pQaQrn5ue"
+ },
+ "outputs": [],
+ "source": [
+ "tflite_model_file = tflite_models_dir/\"mnist_model.tflite\"\n",
+ "tflite_model_file.write_bytes(tflite_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "7BONhYtYocQY"
+ },
+ "source": [
+ "To quantize the model on export, set the `post_training_quantize` flag:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g8PUvLWDlmmz"
+ },
+ "outputs": [],
+ "source": [
+ "# Note: If you don't have a recent tf-nightly installed, the\n",
+ "# \"post_training_quantize\" line will have no effect.\n",
+ "tf.logging.set_verbosity(tf.logging.INFO)\n",
+ "converter.post_training_quantize = True\n",
+ "tflite_quant_model = converter.convert()\n",
+ "tflite_model_quant_file = tflite_models_dir/\"mnist_model_quant.tflite\"\n",
+ "tflite_model_quant_file.write_bytes(tflite_quant_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "PhMmUTl4sbkz"
+ },
+ "source": [
+ "Note how the resulting file, with `post_training_quantize` set, is approximately `1/4` the size."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "JExfcfLDscu4"
+ },
+ "outputs": [],
+ "source": [
+ "!ls -lh {tflite_models_dir}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L8lQHMp_asCq"
+ },
+ "source": [
+ "## Run the TFLite models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-5l6-ciItvX6"
+ },
+ "source": [
+ "We can run the TensorFlow Lite model using the python TensorFlow Lite\n",
+ "Interpreter. \n",
+ "\n",
+ "### load the test data\n",
+ "\n",
+ "First let's load the mnist test data to feed to it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "eTIuU07NuKFL"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()\n",
+ "images, labels = tf.to_float(mnist_test[0])/255.0, mnist_test[1]\n",
+ "\n",
+ "# Note: If you change the batch size, then use \n",
+ "# `tf.contrib.lite.Interpreter.resize_tensor_input` to also change it for\n",
+ "# the interpreter.\n",
+ "mnist_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Ap_jE7QRvhPf"
+ },
+ "source": [
+ "### Load the model into an interpreter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Jn16Rc23zTss"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter = tf.contrib.lite.Interpreter(model_path=str(tflite_model_file))\n",
+ "interpreter.allocate_tensors()\n",
+ "input_index = interpreter.get_input_details()[0][\"index\"]\n",
+ "output_index = interpreter.get_output_details()[0][\"index\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "J8Pztk1mvNVL"
+ },
+ "outputs": [],
+ "source": [
+ "tf.logging.set_verbosity(tf.logging.DEBUG)\n",
+ "interpreter_quant = tf.contrib.lite.Interpreter(model_path=str(tflite_model_quant_file))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "Afl6yGvWyqAr"
+ },
+ "outputs": [],
+ "source": [
+ "interpreter_quant.allocate_tensors()\n",
+ "input_index = interpreter_quant.get_input_details()[0][\"index\"]\n",
+ "output_index = interpreter_quant.get_output_details()[0][\"index\"]\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "2opUt_JTdyEu"
+ },
+ "source": [
+ "### Test the model on one image"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "AKslvo2kwWac"
+ },
+ "outputs": [],
+ "source": [
+ "for img, label in mnist_ds.take(1):\n",
+ " break\n",
+ "\n",
+ "interpreter.set_tensor(input_index, img)\n",
+ "interpreter.invoke()\n",
+ "predictions = interpreter.get_tensor(output_index)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "XZClM2vo3_bm"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pylab as plt\n",
+ "\n",
+ "plt.imshow(img[0])\n",
+ "template = \"True:{true}, predicted:{predict}\"\n",
+ "_ = plt.title(template.format(true= str(label[0].numpy()),\n",
+ " predict=str(predictions[0,0])))\n",
+ "plt.grid(False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "LwN7uIdCd8Gw"
+ },
+ "source": [
+ "### Evaluate the models"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "05aeAuWjvjPx"
+ },
+ "outputs": [],
+ "source": [
+ "def eval_model(interpreter, mnist_ds):\n",
+ " total_seen = 0\n",
+ " num_correct = 0\n",
+ "\n",
+ " for img, label in mnist_ds:\n",
+ " total_seen += 1\n",
+ " interpreter.set_tensor(input_index, img)\n",
+ " interpreter.invoke()\n",
+ " predictions = interpreter.get_tensor(output_index)\n",
+ " if predictions == label.numpy():\n",
+ " num_correct += 1\n",
+ "\n",
+ " if total_seen % 500 == 0:\n",
+ " print(\"Accuracy after %i images: %f\" %\n",
+ " (total_seen, float(num_correct) / float(total_seen)))\n",
+ "\n",
+ " return float(num_correct) / float(total_seen)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "DqXBnDfJ7qxL"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter, mnist_ds))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "Km3cY9ry8ZlG"
+ },
+ "source": [
+ "We can repeat the evaluation on the weight quantized model to obtain:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "-9cnwiPp6EGm"
+ },
+ "outputs": [],
+ "source": [
+ "print(eval_model(interpreter_quant, mnist_ds))\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "L7lfxkor8pgv"
+ },
+ "source": [
+ "\n",
+ "In this example, we have compressed model with no difference in the accuracy."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "M0o1FtmWeKZm"
+ },
+ "source": [
+ "\n",
+ "\n",
+ "## Optimizing an existing model\n",
+ "\n",
+ "We now consider another example. Resnets with pre-activation layers (Resnet-v2) are widely used for vision applications.\n",
+ " Pre-trained frozen graph for resnet-v2-101 is available at the\n",
+ " [Tensorflow Lite model repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md).\n",
+ "\n",
+ "We can convert the frozen graph to a TFLite flatbuffer with quantization by:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "v5p5VcNPjILQ"
+ },
+ "outputs": [],
+ "source": [
+ "archive_path = tf.keras.utils.get_file(\"resnet_v2_101.tgz\", \"https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz\", extract=True)\n",
+ "archive_path = pathlib.Path(archive_path)\n",
+ "archive_dir = str(archive_path.parent)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "-sxnXQuC4ThD"
+ },
+ "source": [
+ "The `info.txt` file lists the input and output names. You can also find them using TensorBoard to visually inspect the graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "g_Q_OMEJ4LIc"
+ },
+ "outputs": [],
+ "source": [
+ "! cat {archive_dir}/resnet_v2_101_299_info.txt"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "ujCAFhqm-C6H"
+ },
+ "outputs": [],
+ "source": [
+ "graph_def_file = pathlib.Path(archive_path).parent/\"resnet_v2_101_299_frozen.pb\"\n",
+ "input_arrays = [\"input\"] \n",
+ "output_arrays = [\"output\"]\n",
+ "converter = tf.contrib.lite.TocoConverter.from_frozen_graph(\n",
+ " str(graph_def_file), input_arrays, output_arrays, input_shapes={\"input\":[1,299,299,3]})\n",
+ "converter.post_training_quantize = True\n",
+ "resnet_tflite_file = graph_def_file.parent/\"resnet_v2_101_quantized.tflite\"\n",
+ "resnet_tflite_file.write_bytes(converter.convert())\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab": {},
+ "colab_type": "code",
+ "id": "vhOjeg1x9Knp"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "!ls -lh {archive_dir}/*.tflite"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "qqHLaqFMCjRZ"
+ },
+ "source": [
+ "\n",
+ "The model size reduces from 171 MB to 43 MB.\n",
+ "The accuracy of this model on imagenet can be evaluated using the scripts provided for [TFLite accuracy measurement](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/accuracy/ilsvrc).\n",
+ "\n",
+ "The optimized model top-1 accuracy is 76.8, the same as the floating point model."
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "post-training-quant.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
+ },
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
index 4ec539ab42..9c389144ff 100644
--- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py
@@ -61,7 +61,7 @@ def pairwise_distance_np(feature, squared=False):
class ContrastiveLossTest(test.TestCase):
def testContrastive(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -90,7 +90,7 @@ class ContrastiveLossTest(test.TestCase):
class TripletSemiHardLossTest(test.TestCase):
def testTripletSemiHard(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -146,7 +146,7 @@ class TripletSemiHardLossTest(test.TestCase):
class LiftedStructLossTest(test.TestCase):
def testLiftedStruct(self):
- with self.test_session():
+ with self.cached_session():
num_data = 10
feat_dim = 6
margin = 1.0
@@ -217,7 +217,7 @@ def convert_to_list_of_sparse_tensor(np_matrix):
class NpairsLossTest(test.TestCase):
def testNpairs(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
num_classes = 5
@@ -261,7 +261,7 @@ class NpairsLossTest(test.TestCase):
class NpairsLossMultiLabelTest(test.TestCase):
def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
reg_lambda = 0.02
@@ -290,7 +290,7 @@ class NpairsLossMultiLabelTest(test.TestCase):
self.assertAllClose(loss_npairs, loss_npairs_multilabel)
def testNpairsMultiLabel(self):
- with self.test_session():
+ with self.cached_session():
num_data = 15
feat_dim = 6
num_classes = 10
@@ -527,7 +527,7 @@ class ClusterLossTest(test.TestCase):
def testClusteringLossPAMOff(self):
if not HAS_SKLEARN:
return
- with self.test_session():
+ with self.cached_session():
margin_multiplier = 10.0
embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
@@ -544,7 +544,7 @@ class ClusterLossTest(test.TestCase):
def testClusteringLossPAMOn(self):
if not HAS_SKLEARN:
return
- with self.test_session():
+ with self.cached_session():
margin_multiplier = 10.0
embeddings, labels = self._genClusters(n_samples=128, n_clusters=64)
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 08de54b8e1..f81a90809a 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -253,6 +253,7 @@ tensorflow/core/kernels/strided_slice_op_inst_5.cc
tensorflow/core/kernels/strided_slice_op_inst_6.cc
tensorflow/core/kernels/strided_slice_op_inst_7.cc
tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/string_util.cc
tensorflow/core/kernels/tensor_array.cc
tensorflow/core/kernels/tensor_array_ops.cc
tensorflow/core/kernels/tile_functor_cpu.cc
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index c35e60a554..b1c852c2c6 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -31,6 +31,7 @@ from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _graph_util
from tensorflow.python.framework import importer as _importer
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.platform import tf_logging as _logging
from tensorflow.python.saved_model import constants as _saved_model_constants
from tensorflow.python.training import saver as _saver_lib
from tensorflow.python.util import compat as _compat
@@ -476,6 +477,12 @@ def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
collection.bytes_list.value[:] = [
s for s in base_collection.bytes_list.value
if not _is_removed_mentioned(s, removed_op_names)]
+ _logging.info(
+ 'In collection %s, nodes excluded are: %s', collection_name,
+ sorted([
+ s for s in base_collection.bytes_list.value
+ if _is_removed_mentioned(s, removed_op_names)
+ ]))
elif base_collection.HasField('node_list'):
collection.node_list.value[:] = [
s for s in base_collection.node_list.value
@@ -745,6 +752,9 @@ def meta_graph_transform(
retained_op_names = [_compat.as_str(node.name)
for node in meta_graph_def.graph_def.node]
removed_op_names = set(base_op_names) - set(retained_op_names)
+ _logging.info('Node names in base graph: %s', sorted(base_op_names))
+ _logging.info('Node names retained: %s', sorted(retained_op_names))
+ _logging.info('Node names removed: %s', sorted(removed_op_names))
# Copy saver, excluding any pruned nodes if graph was not frozen.
# TODO(b/63447631): Revisit this once the problem is addressed. Currently
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
index 1d18d6beff..bed1ecb71c 100644
--- a/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
+++ b/tensorflow/contrib/metrics/python/kernel_tests/histogram_ops_test.py
@@ -31,21 +31,21 @@ class Strict1dCumsumTest(test.TestCase):
"""Test this private function."""
def test_empty_tensor_returns_empty(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([])
result = histogram_ops._strict_1d_cumsum(tensor, 0)
expected = constant_op.constant([])
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_1_tensor_works(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([3], dtype=dtypes.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 1)
expected = constant_op.constant([3], dtype=dtypes.float32)
np.testing.assert_array_equal(expected.eval(), result.eval())
def test_length_3_tensor_works(self):
- with self.test_session():
+ with self.cached_session():
tensor = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
result = histogram_ops._strict_1d_cumsum(tensor, 3)
expected = constant_op.constant([1, 3, 6], dtype=dtypes.float32)
@@ -58,7 +58,7 @@ class AUCUsingHistogramTest(test.TestCase):
self.rng = np.random.RandomState(0)
def test_empty_labels_and_scores_gives_nan_auc(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([], shape=[0], dtype=dtypes.bool)
scores = constant_op.constant([], shape=[0], dtype=dtypes.float32)
score_range = [0, 1.]
@@ -155,7 +155,7 @@ class AUCUsingHistogramTest(test.TestCase):
from synthetic data.
"""
score_range = [0, 1.] or score_range
- with self.test_session():
+ with self.cached_session():
labels = array_ops.placeholder(dtypes.bool, shape=[num_records])
scores = array_ops.placeholder(dtypes.float32, shape=[num_records])
auc, update_op = histogram_ops.auc_using_histogram(
diff --git a/tensorflow/contrib/metrics/python/metrics/classification_test.py b/tensorflow/contrib/metrics/python/metrics/classification_test.py
index 3d0b81c1be..d6a670f97b 100644
--- a/tensorflow/contrib/metrics/python/metrics/classification_test.py
+++ b/tensorflow/contrib/metrics/python/metrics/classification_test.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import test
class ClassificationTest(test.TestCase):
def testAccuracy1D(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -44,7 +44,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DBool(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.bool, shape=[None])
labels = array_ops.placeholder(dtypes.bool, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -54,7 +54,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DInt64(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int64, shape=[None])
labels = array_ops.placeholder(dtypes.int64, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -64,7 +64,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DString(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.string, shape=[None])
labels = array_ops.placeholder(dtypes.string, shape=[None])
acc = classification.accuracy(pred, labels)
@@ -87,7 +87,7 @@ class ClassificationTest(test.TestCase):
classification.accuracy(pred, labels)
def testAccuracy1DWeighted(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
weights = array_ops.placeholder(dtypes.float32, shape=[None])
@@ -101,7 +101,7 @@ class ClassificationTest(test.TestCase):
self.assertEqual(result, 0.5)
def testAccuracy1DWeightedBroadcast(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
pred = array_ops.placeholder(dtypes.int32, shape=[None])
labels = array_ops.placeholder(dtypes.int32, shape=[None])
weights = array_ops.placeholder(dtypes.float32, shape=[])
@@ -161,7 +161,7 @@ class F1ScoreTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes.int64, seed=2)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -176,7 +176,7 @@ class F1ScoreTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes.float32)
labels = constant_op.constant(inputs)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -191,7 +191,7 @@ class F1ScoreTest(test.TestCase):
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run([f1_op])
# Threshold 0 will have around 0.5 precision and 1 recall yielding an F1
@@ -201,7 +201,7 @@ class F1ScoreTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(10000, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes.float32)
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -214,7 +214,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(2 * 0.5 * 1 / (1 + 0.5), f1.eval(), places=2)
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -228,7 +228,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(1.0, f1.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -242,7 +242,7 @@ class F1ScoreTest(test.TestCase):
self.assertAlmostEqual(1.0, f1.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes.float32)
labels = array_ops.zeros([4])
f1, f1_op = classification.f1_score(predictions, labels, num_thresholds=3)
@@ -300,7 +300,7 @@ class F1ScoreTest(test.TestCase):
f1, f1_op = classification.f1_score(tf_labels, tf_predictions,
num_thresholds=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in range(num_batches):
sess.run([f1_op])
diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
index fcce52a07a..a5621b44cd 100644
--- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
+++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py
@@ -66,10 +66,11 @@ class LossScaleOptimizer(optimizer.Optimizer):
# Choose a loss scale manager which decides how to pick the right loss scale
# throughout the training process.
- loss_scale_manger = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
+ loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(5000)
# Wraps the original optimizer in a LossScaleOptimizer.
- loss_scale_optimizer = LossScaleOptimizer(opt, loss_scale_manager)
+ loss_scale_optimizer =
+ tf.contrib.mixed_precision.LossScaleOptimizer(opt, loss_scale_manager)
# Call minimize() on the loss scale optimizer.
train_op = loss_scale_optimizer.minimize(loss)
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
index 6a7f5efecd..b9967fe76d 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc
@@ -136,8 +136,8 @@ void MPIRemoteRendezvous::RecvFromRemoteAsync(
MPIRendezvousMgr* mgr =
reinterpret_cast<MPIRendezvousMgr*>(this->rendezvous_mgr_);
- mgr->QueueRequest(parsed.FullKey().ToString(), step_id_,
- std::move(request_call), rendezvous_call);
+ mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call),
+ rendezvous_call);
}
MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
@@ -258,7 +258,7 @@ void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
std::function<MPISendTensorCall*()> res = std::bind(
send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call);
- SendQueueEntry req(parsed.FullKey().ToString().c_str(), std::move(res));
+ SendQueueEntry req(string(parsed.FullKey()), std::move(res));
this->QueueSendRequest(req);
diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
index 5596601ddb..90140fcab3 100644
--- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
+++ b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h
@@ -71,7 +71,7 @@ class MPISendTensorCall {
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id,
const bool is_dead) {
- mRes_.set_key(parsed.FullKey().ToString());
+ mRes_.set_key(string(parsed.FullKey()));
mRes_.set_step_id(step_id);
mRes_.mutable_response()->set_is_dead(is_dead);
mRes_.mutable_response()->set_send_start_micros(
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 62996d1fd8..9a9d480260 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -31,9 +31,11 @@ tf_custom_op_library(
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
]),
- deps = if_cuda([
+ deps = [] + if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:protos_all_proto_text",
]),
)
@@ -57,32 +59,31 @@ tf_cuda_cc_test(
"notap",
],
deps =
- [
+ if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- "@local_config_nccl//:nccl",
- ],
+ ]),
)
tf_kernel_library(
name = "nccl_kernels",
- srcs = [
+ srcs = if_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
"kernels/nccl_rewrite.cc",
- ],
- deps = [
+ ]),
+ deps = if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
- "//tensorflow/core:proto_text",
"//tensorflow/core:stream_executor",
- "@local_config_nccl//:nccl",
- ],
+ ]),
alwayslink = 1,
)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index 4676e937e5..06ff86e6d8 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 2e4d61d931..f4ac70eb1a 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -16,6 +16,7 @@ py_library(
"__init__.py",
"python/training/adamax.py",
"python/training/addsign.py",
+ "python/training/agn_optimizer.py",
"python/training/drop_stale_gradient_optimizer.py",
"python/training/elastic_average_optimizer.py",
"python/training/external_optimizer.py",
@@ -246,6 +247,27 @@ tf_py_test(
)
tf_py_test(
+ name = "agn_optimizer_test",
+ srcs = ["python/training/agn_optimizer_test.py"],
+ additional_deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//third_party/py/numpy",
+ ],
+ tags = [
+ "notap", # this test launches a local server
+ ],
+)
+
+tf_py_test(
name = "elastic_average_optimizer_test",
srcs = ["python/training/elastic_average_optimizer_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index ad7d7cfa6e..c7ea68efa9 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -1,4 +1,4 @@
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.opt.python.training.adamax import *
from tensorflow.contrib.opt.python.training.addsign import *
+from tensorflow.contrib.opt.python.training.agn_optimizer import *
from tensorflow.contrib.opt.python.training.drop_stale_gradient_optimizer import *
from tensorflow.contrib.opt.python.training.elastic_average_optimizer import *
from tensorflow.contrib.opt.python.training.external_optimizer import *
@@ -60,6 +61,8 @@ _allowed_symbols = [
'VariableClippingOptimizer',
'MultitaskOptimizerWrapper',
'clip_gradients_by_global_norm',
+ 'AGNOptimizer',
+ 'AGNCustomGetter',
'ElasticAverageOptimizer',
'ElasticAverageCustomGetter',
'ModelAverageOptimizer',
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer.py b/tensorflow/contrib/opt/python/training/agn_optimizer.py
new file mode 100644
index 0000000000..9d8bab8d33
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer.py
@@ -0,0 +1,262 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import session_run_hook
+
+GLOBAL_VARIABLE_NAME = 'global_center_variable'
+GRAD_VARIABLE_NAME = 'grad_variable'
+
+
+class AGNCustomGetter(object):
+ """Custom_getter class is used to do:
+
+ 1. Change trainable variables to local collection and place them at worker
+ device
+ 2. Generate global variables(global center variables)
+ 3. Generate grad variables(gradients) which record the gradients sum
+ and place them at worker device
+ Notice that the class should be used with tf.replica_device_setter,
+ so that the global center variables and global step variable can be placed
+ at ps device.
+ """
+
+ def __init__(self, worker_device):
+ """
+ Args:
+ worker_device: put the grad_variables on worker device
+ """
+ self._worker_device = worker_device
+ self._global_map = {}
+ self._grad_map = {}
+
+ def __call__(self, getter, name, trainable, collections, *args, **kwargs):
+ if trainable:
+ with ops.device(self._worker_device):
+ local_var = getter(
+ name,
+ trainable=True,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['reuse'] == True:
+ return local_var
+ global_center_variable = getter(
+ name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ *args,
+ **kwargs)
+
+ with ops.device(self._worker_device):
+ grad_variable = getter(
+ name='%s/%s' % (GRAD_VARIABLE_NAME, name),
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ *args,
+ **kwargs)
+ if kwargs['partitioner'] is None:
+ self._grad_map[local_var] = grad_variable
+ self._global_map[local_var] = global_center_variable
+ else:
+ v_list = list(local_var)
+ for i in range(len(v_list)):
+ self._grad_map[v_list[i]] = list(grad_variable)[i]
+ self._global_map[v_list[i]] = list(global_center_variable)[i]
+ return local_var
+ else:
+ return getter(
+ name, trainable=trainable, collections=collections, *args, **kwargs)
+
+
+class AGNOptimizer(optimizer.Optimizer):
+ """Wrapper that implements the Accumulated GradientNormalization algorithm.
+
+ Reference:
+ Accumulated Gradient Normalization: Joeri Hermans ACML2017
+ https://arxiv.org/abs/1710.02368
+ """
+
+ def __init__(self,
+ optimizer,
+ num_worker,
+ custom_getter,
+ communication_period=10,
+ use_locking=True,
+ name='AGNOptimizer'):
+ """Construct a new AGN optimizer.
+
+ Args:
+ optimizer: input optimizer, can be sgd/momentum/adam etc.
+ num_worker: The number of workers
+ custom_getter: The AGNCustomGetter
+ communication_period: An int point value to controls the frequency of the
+ communication between every worker and the ps.
+ use_locking: If True use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "AGNOptimizer".
+ """
+ super(AGNOptimizer, self).__init__(use_locking, name)
+ self._opt = optimizer
+ self._num_worker = num_worker
+ self._period = communication_period
+ self._global_map = custom_getter._global_map
+ self._grad_map = custom_getter._grad_map
+ self._local_step = variable_scope.get_variable(
+ initializer=0,
+ trainable=False,
+ collections=[ops.GraphKeys.LOCAL_VARIABLES],
+ name='local_step')
+ self._opt._prepare()
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to global variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the variables
+ have been updated.
+ name: Optional name for the returned operation. Default to the name
+ passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+ """
+ local_vars = [v for g, v in grads_and_vars if g is not None]
+ grads = [g for g, v in grads_and_vars if g is not None]
+
+ def _variable_creator(next_creator, collections, **kwargs):
+ if not collections:
+ collections = [ops.GraphKeys.LOCAL_VARIABLES]
+ elif ops.GraphKeys.GLOBAL_VARIABLES in collections:
+ collections = list(collections)
+ collections.append(ops.GraphKeys.LOCAL_VARIABLES)
+ collections.remove(ops.GraphKeys.GLOBAL_VARIABLES)
+ return next_creator(collections=collections, **kwargs)
+
+ # theta = theta - lr * grad
+ with variable_scope.variable_creator_scope(_variable_creator):
+ local_update_op = self._opt.apply_gradients(grads_and_vars)
+
+ # a = a + grad
+ update_ops = []
+ update_ops.append(local_update_op)
+ grad_vars = [self._grad_map[var] for var in local_vars]
+ for g, grad_var in zip(grads, grad_vars):
+ update_ops.append(state_ops.assign_add(grad_var, g))
+
+ global_center_vars = [self._global_map[var] for var in local_vars]
+
+ # update global variables.
+ def _Update_global_variables():
+ global_norm = []
+ # a = a / t
+ for g in grad_vars:
+ global_norm.append(state_ops.assign(g, g / self._period))
+ # apply
+ with ops.control_dependencies(global_norm):
+ apply_global_op = self._opt.apply_gradients(
+ zip(grad_vars, global_center_vars))
+
+ # pull
+ with ops.control_dependencies([apply_global_op]):
+ update_ops = []
+ if global_step:
+ with ops.colocate_with(global_step):
+ update_ops.append(state_ops.assign_add(global_step, 1))
+
+ for lvar in local_vars:
+ g_val = self._global_map[lvar].read_value()
+ update_ops.append(state_ops.assign(lvar, g_val))
+ for grad_var in grad_vars:
+ update_ops.append(
+ state_ops.assign(grad_var, array_ops.zeros_like(grad_var)))
+ variable_update = control_flow_ops.group(*(update_ops))
+ return variable_update
+
+ local_update = state_ops.assign_add(
+ self._local_step, 1, name='local_step_update').op
+
+ with ops.control_dependencies([local_update]):
+ condition = math_ops.equal(
+ math_ops.mod(self._local_step, self._period), 0)
+ with ops.control_dependencies(update_ops):
+ conditional_update = control_flow_ops.cond(
+ condition, _Update_global_variables, control_flow_ops.no_op)
+ return conditional_update
+
+ def get_init_op(self, task_index):
+ """Returns the op to let all the local variables and local center
+
+ variables equal to the global center variables before the training begins
+ """
+ init_ops = []
+ local_vars = variables.trainable_variables()
+ global_center_vars = [self._global_map[var] for var in local_vars]
+ grad_vars = [self._grad_map[var] for var in local_vars]
+ if not (local_vars and global_center_vars and grad_vars):
+ raise ValueError('The lists of local_variables, global_center_variables,'
+ 'grad_center_variables should not be empty')
+ for lvar, gc_var in zip(local_vars, global_center_vars):
+ init_ops.append(state_ops.assign(lvar, gc_var))
+ for g in grad_vars:
+ init_ops.append(state_ops.assign(g, array_ops.zeros_like(g)))
+ init_op = control_flow_ops.group(*(init_ops))
+ return init_op
+
+ def make_session_run_hook(self, is_chief, task_index):
+ """Creates a hook to handle AGNOptimizerHook ops such as initialization."""
+ return _AGNOptimizerHook(self, is_chief, task_index)
+
+
+class _AGNOptimizerHook(session_run_hook.SessionRunHook):
+
+ def __init__(self, agn_optimizer, is_chief, task_index):
+ """Creates hook to handle AGNOptimizer initialization ops.
+
+ Args:
+ agn_optimizer: `AGNOptimizer` which this hook will initialize.
+ is_chief: `Bool`, whether is this a chief replica or not.
+ task_index: int, task_index of worker
+ """
+ self._agn_optimizer = agn_optimizer
+ self._is_chief = is_chief
+ self._task_index = task_index
+
+ def begin(self):
+ self._local_init_op = variables.local_variables_initializer()
+ self._global_init_op = None
+ if self._is_chief:
+ self._global_init_op = variables.global_variables_initializer()
+ self._variable_init_op = self._agn_optimizer.get_init_op(self._task_index)
+
+ def after_create_session(self, session, coord):
+ """Run initialization ops"""
+ session.run(self._variable_init_op)
diff --git a/tensorflow/contrib/opt/python/training/agn_optimizer_test.py b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
new file mode 100644
index 0000000000..d3da290bdb
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/agn_optimizer_test.py
@@ -0,0 +1,281 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for EAOptimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import portpicker
+
+from tensorflow.contrib.opt.python.training import agn_optimizer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adam
+from tensorflow.python.training import device_setter
+from tensorflow.python.training import server_lib
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+
+def create_local_cluster(num_workers, num_ps, protocol="grpc"):
+ """Create local GRPC servers and return them."""
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs, job_name="worker", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs, job_name="ps", protocol=protocol, task_index=ix, start=True)
+ for ix in range(num_ps)
+ ]
+
+ return cluster_dict, workers, ps_servers
+
+
+# Creates the workers and return their sessions, graphs, train_ops.
+# Cheif worker will update at last
+def _get_workers(num_workers, period, workers, num_ps=1):
+ sessions = []
+ graphs = []
+ train_ops = []
+ for worker_id in range(num_workers):
+ graph = ops.Graph()
+ is_chief = (worker_id == 0)
+ with graph.as_default():
+ worker_device = "/job:worker/task:%d/cpu:0" % (worker_id)
+ ps_device = device_setter.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/task:0/cpu:0",
+ ps_tasks=1)
+ agn_getter = agn_optimizer.AGNCustomGetter(worker_device=worker_device)
+ with variable_scope.variable_scope(
+ "", custom_getter=agn_getter), ops.device(ps_device):
+ global_step = training_util.get_or_create_global_step()
+ var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
+ var_1 = variable_scope.get_variable(initializer=0.5, name="v1")
+ if num_ps > 1:
+ with variable_scope.variable_scope(
+ "",
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_ps, axis=0),
+ custom_getter=agn_getter), ops.device(ps_device):
+
+ partition_var = variable_scope.get_variable(
+ "partition_var",
+ shape=[2, 4],
+ initializer=init_ops.zeros_initializer)
+ part_0 = list(partition_var)[0]
+ part_1 = list(partition_var)[1]
+
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ grads_part_0 = constant_op.constant([[-1., -1., -1., -1.]])
+ grads_part_1 = constant_op.constant([[-1., -1., -1., -1.]])
+
+ optimizer = \
+ adam.AdamOptimizer(learning_rate=0.1, beta1=0.0, beta2=0.0)
+ opt = agn_optimizer.AGNOptimizer(
+ optimizer,
+ num_worker=num_workers,
+ communication_period=period,
+ custom_getter=agn_getter)
+ if num_ps == 1:
+ train_op = [
+ opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]),
+ global_step)
+ ]
+ else:
+ train_op = [
+ opt.apply_gradients(
+ ([grads_0, var_0], [grads_1, var_1], [grads_part_0, part_0],
+ [grads_part_1, part_1]), global_step)
+ ]
+ hook = opt.make_session_run_hook(is_chief, worker_id)
+ # Creates MonitoredSession
+ sess = training.MonitoredTrainingSession(
+ workers[worker_id].target, hooks=[hook])
+
+ sessions.append(sess)
+ graphs.append(graph)
+ train_ops.append(train_op)
+
+ return sessions, graphs, train_ops
+
+
+class AGNOptimizerTest(test.TestCase):
+
+ def _run(self, train_op, sess):
+ sess.run(train_op)
+
+ def test1Workers2Period(self):
+ num_workers = 1
+ communication_period = 4
+ num_ps = 1
+ _, workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ sessions, graphs, train_ops = _get_workers(num_workers,
+ communication_period, workers)
+
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
+ global_step = training_util.get_global_step(graphs[0])
+ var_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ var_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
+
+ # verify adam/beta variables not in global collection
+ with graphs[0].as_default():
+ for ele in variables.global_variables():
+ self.assertTrue(ele.op.name.find("beta") < 0)
+ if ele.op.name.find("global_center_variable") < 0:
+ self.assertTrue(ele.op.name.find("Adam") < 0)
+
+ # Verify the initialized value.
+ self.assertAllEqual(0.0, sessions[0].run(var_0))
+ self.assertAllEqual(0.5, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+ # step 0
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.1, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.6, sessions[0].run(var_1), 1e-6)
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+ self.assertAllEqual(0, sessions[0].run(global_step))
+
+ # 2 & 3
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.3, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.8, sessions[0].run(var_1), 1e-6)
+
+ # 4
+ sessions[0].run(train_ops[0])
+ # pull
+ self.assertAllEqual(sessions[0].run(var_0), sessions[0].run(var_0_g))
+ self.assertAllEqual(sessions[0].run(var_1), sessions[0].run(var_1_g))
+ self.assertNear(0.1, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.6, sessions[0].run(var_1), 1e-6)
+
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ sessions[0].run(train_ops[0])
+ self.assertAllEqual(sessions[0].run(var_0), sessions[0].run(var_0_g))
+ self.assertAllEqual(sessions[0].run(var_1), sessions[0].run(var_1_g))
+ self.assertNear(0.2, sessions[0].run(var_0), 1e-6)
+ self.assertNear(0.7, sessions[0].run(var_1), 1e-6)
+
+ def test2Worker1Period(self):
+ num_workers = 2
+ communication_period = 1
+ num_ps = 2
+ _, workers, _ = create_local_cluster(num_workers=num_workers, num_ps=num_ps)
+
+ sessions, graphs, train_ops = _get_workers(
+ num_workers, communication_period, workers, num_ps=2)
+
+ var_0 = graphs[0].get_tensor_by_name("v0:0")
+ var_1 = graphs[0].get_tensor_by_name("v1:0")
+
+ var_0_1 = graphs[1].get_tensor_by_name("v0:0")
+ var_1_1 = graphs[1].get_tensor_by_name("v1:0")
+
+ var_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v0:0")
+ var_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME + "/v1:0")
+ part_0_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME +
+ "/partition_var/part_0:0")
+ part_1_g = graphs[0].get_tensor_by_name(
+ agn_optimizer.GLOBAL_VARIABLE_NAME +
+ "/partition_var/part_1:0")
+
+ # Verify the initialized value.
+ self.assertAllEqual(0.0, sessions[0].run(var_0))
+ self.assertAllEqual(0.5, sessions[0].run(var_1))
+ self.assertAllEqual(0.0, sessions[1].run(var_0_1))
+ self.assertAllEqual(0.5, sessions[1].run(var_1_1))
+ self.assertAllEqual(0.0, sessions[0].run(var_0_g))
+ self.assertAllEqual(0.5, sessions[0].run(var_1_g))
+
+ # verify each step
+ sessions[0].run(train_ops[0])
+ self.assertNear(0.1, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.1, 0.1, 0.1, 0.1], sessions[0].run(part_1_g),
+ 1e-6)
+
+ sessions[1].run(train_ops[1])
+ self.assertNear(0.2, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.2, 0.2, 0.2, 0.2], sessions[0].run(part_1_g),
+ 1e-6)
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+
+ sessions[0].run(train_ops[0])
+ sessions[1].run(train_ops[1])
+ self.assertNear(0.6, sessions[0].run(var_0_g), 1e-6)
+ self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6], sessions[0].run(part_0_g),
+ 1e-6)
+ self.assertNDArrayNear([0.6, 0.6, 0.6, 0.6], sessions[0].run(part_1_g),
+ 1e-6)
+
+ def testAGNCustomGetter(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "ps": ["ps0:2222", "ps1:2222"],
+ "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
+ })
+ agn_getter = agn_optimizer.AGNCustomGetter(
+ worker_device="/job:worker/task:0")
+ with ops.device(
+ device_setter.replica_device_setter(cluster=cluster_spec,
+ worker_device="/job:worker/task:0",
+ ps_device="/job:ps")), \
+ variable_scope.variable_scope("", custom_getter=agn_getter):
+ v = variable_scope.get_variable(initializer=[1, 2], name="v")
+ w = variable_scope.get_variable(initializer=[2, 1], name="w")
+ v_g, w_g = agn_getter._global_map[v], agn_getter._global_map[w]
+ self.assertDeviceEqual("/job:worker/task:0", v.device)
+ self.assertDeviceEqual("job:ps/task:0", v_g.device)
+ self.assertDeviceEqual("/job:worker/task:0", w.device)
+ self.assertDeviceEqual("job:ps/task:1", w_g.device)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index f08ffaa36f..089ecf597d 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -236,7 +236,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -249,7 +249,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -286,7 +286,7 @@ class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index c333d1e089..dab1e02716 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -31,7 +31,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
@@ -64,18 +64,17 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
for v in var_list:
- # TODO(isaprykin): Delete colocate_with(v) from other optimizers and
- # confirm that colocation will happen anyway.
dtype = v.dtype.base_dtype
if v.get_shape().is_fully_defined():
init = init_ops.constant_initializer(self._initial_accumulator_value,
dtype=dtype)
else:
- # Use a Tensor instead of initializer if variable does not have static
- # shape.
- init_constant = gen_array_ops.fill(
- array_ops.shape(v), self._initial_accumulator_value)
- init = math_ops.cast(init_constant, dtype)
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
"accumulator")
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index 72ea777ca7..d50b52b8ff 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -27,7 +27,7 @@ py_library(
":contrib_estimator_predictor",
":core_estimator_predictor",
":saved_model_predictor",
- "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -89,7 +89,6 @@ py_library(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:signature_constants",
],
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index 95da6d04ed..03399396df 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -23,7 +23,6 @@ import logging
from tensorflow.contrib.predictor import predictor
from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
@@ -68,23 +67,19 @@ def _get_signature_def(signature_def_key, export_dir, tags):
metagraph_def = get_meta_graph_def(export_dir, tags)
try:
- signature_def = signature_def_utils.get_signature_def_by_key(
- metagraph_def,
+ signature_def = metagraph_def.signature_def[signature_def_key]
+ except KeyError as e:
+ formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
signature_def_key)
- except ValueError as e:
try:
- formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
- signature_def_key)
- signature_def = signature_def_utils.get_signature_def_by_key(
- metagraph_def, formatted_key)
-
- logging.warning('Could not find signature def "%s". '
- 'Using "%s" instead', signature_def_key, formatted_key)
- except ValueError:
+ signature_def = metagraph_def.signature_def[formatted_key]
+ except KeyError:
raise ValueError(
'Got signature_def_key "{}". Available signatures are {}. '
'Original error:\n{}'.format(
signature_def_key, list(metagraph_def.signature_def), e))
+ logging.warning('Could not find signature def "%s". '
+ 'Using "%s" instead', signature_def_key, formatted_key)
return signature_def
diff --git a/tensorflow/contrib/quantization/README.md b/tensorflow/contrib/quantization/README.md
index 359950aaf3..826e8db2d3 100644
--- a/tensorflow/contrib/quantization/README.md
+++ b/tensorflow/contrib/quantization/README.md
@@ -2,6 +2,6 @@ The contrib/quantization package exposes a few TensorFlow quantization operation
If you are looking for quantized training rewrites that allow for training
quantized models that work with
-[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/), you should look at
+[TensorFlow Lite](https://www.tensorflow.org/lite/), you should look at
the [contrib/quantize](https://www.tensorflow.org/api_docs/python/tf/contrib/quantize)
package.
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index c59f667f6a..23e3a25d71 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -20,9 +20,13 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":common",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 27a933c0f9..0ab19c91bb 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -1,65 +1,155 @@
-# Quantized Training Rewrites
+# Quantization-aware training
-tf.contrib.quantize provides tools for transforming graphs to include ops to
-model quantization of weights, biases and activations during both training and
-inference. The details of the transformation implemented in this package is
-described here [1].
+Quantization-aware model training ensures that the forward pass matches precision
+for both training and inference. There are two aspects to this:
-This is done using the
-[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization).
+* Operator fusion at inference time are accurately modeled at training time.
+* Quantization effects at inference are modeled at training time.
-Literature has shown that fixed point networks provide comparable performance to
-floating point networks [2]. This is achieved by modeling the quantization
-operation during training in both the forward and backward passes.
-The fake quantization operator achieves this by modeling the quantizer as a pass
-through estimator [3]. Note that during back propagation, the parameters are
+For efficient inference, TensorFlow combines batch normalization with the preceding
+convolutional and fully-connected layers prior to quantization by
+[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}.
+
+The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization)
+nodes to simulate the effect of quantization in the forward and backward passes. The
+forward-pass models quantization, while the backward-pass models quantization as a
+straight-through estimator. Both the forward- and backward-pass simulate the quantization
+of weights and activations. Note that during back propagation, the parameters are
updated at high precision as this is needed to ensure sufficient precision in
-accumulating tiny adjustments to the parameters. However, for the forward pass,
-the parameters and activations are quantized to the desired lower precision.
+accumulating tiny adjustments to the parameters.
+
-## How to use the Rewrites
+Additionally, the minimum and maximum values for activations are determined
+during training. This allows a model trained with quantization in the loop to be
+converted to a fixed point inference model with little effort, eliminating the
+need for a separate calibration step.
-tf.contrib.quantize provides two rewrites, one to train for quantization and
-one to create a [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/)
-compatible eval graph.
+Since it's difficult to add these fake quantization operations to all the
+required locations in the model, there's a function available that rewrites the
+training graph. To create a fake quantized training graph:
```
# Build forward pass of model.
-…
loss = tf.losses.get_total_loss()
-# Call the training rewrite which rewrites the graph in-place with FakeQuantization nodes
-# and folds batchnorm for training.
-# It is often needed to finetune a floating point model for quantization with this training tool.
-# When training from scratch, quant_delay can be used to activate quantization after
-# training to convergence with the float graph, effectively finetuning the model.
-tf.contrib.quantize.create_training_graph(quant_delay=2000000)
+# Call the training rewrite which rewrites the graph in-place with
+# FakeQuantization nodes and folds batchnorm for training. It is
+# often needed to fine tune a floating point model for quantization
+# with this training tool. When training from scratch, quant_delay
+# can be used to activate quantization after training to converge
+# with the float graph, effectively fine-tuning the model.
+g = tf.get_default_graph()
+tf.contrib.quantize.create_training_graph(input_graph=g,
+ quant_delay=2000000)
# Call backward pass optimizer as usual.
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
optimizer.minimize(loss)
```
-Additionally, the rewritten eval graph is non-trivially different from the
-training graph due the effects of quantization on batch normalization. Thus,
-we offer a separate rewrite for the eval_graph.
+The rewritten *eval graph* is non-trivially different from the *training graph*
+since the quantization ops affect the batch normalization step. Because of this,
+we've added a separate rewrite for the *eval graph*:
```
# Build eval model
-…
-logits = tf.nn.softmax_cross_entropy_with_logits(...)
+logits = tf.nn.softmax_cross_entropy_with_logits_v2(...)
-# Call the eval rewrite which rewrites the graph in-place with FakeQuantization nodes
-# and fold batchnorm for eval.
-tf.contrib.quantize.create_eval_graph()
+# Call the eval rewrite which rewrites the graph in-place with
+# FakeQuantization nodes and fold batchnorm for eval.
+g = tf.get_default_graph()
+tf.contrib.quantize.create_eval_graph(input_graph=g)
-# Save the checkpoint and eval graph proto to disk for freezing and providing to TFLite.
+# Save the checkpoint and eval graph proto to disk for freezing
+# and providing to TFLite.
with open(eval_graph_file, ‘w’) as f:
f.write(str(g.as_graph_def()))
saver = tf.train.Saver()
saver.save(sess, checkpoint_name)
```
+Methods to rewrite the training and eval graphs are an active area of research
+and experimentation. Although rewrites and quantized training might not work or
+improve performance for all models, we are working to generalize these techniques.
+
+
+## Generating fully-quantized models
+
+The previously demonstrated after-rewrite eval graph only *simulates*
+quantization. To generate real fixed-point computations from a trained
+quantization model, convert it to a fixed-point kernel. TensorFlow Lite supports
+this conversion from the graph resulting from `create_eval_graph`.
+
+First, create a frozen graph that will be the input for the TensorFlow Lite
+toolchain:
+
+```
+freeze_graph \
+ --input_graph=eval_graph_def.pb \
+ --input_checkpoint=checkpoint \
+ --output_graph=frozen_eval_graph.pb --output_node_names=outputs
+```
+
+Provide this to the TensorFlow Lite Optimizing Converter (TOCO) to get a
+fully-quantized TensorFlow Lite model:
+
+```
+toco \
+ --input_file=frozen_eval_graph.pb \
+ --output_file=tflite_model.tflite \
+ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
+ --inference_type=QUANTIZED_UINT8 \
+ --input_shape="1,224, 224,3" \
+ --input_array=input \
+ --output_array=outputs \
+ --std_value=127.5 --mean_value=127.5
+```
+
+See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/).
+
+
+## Quantized accuracy results
+
+The following are results of trainiing some popular CNN models (Mobilenet-v1,
+Mobilenet-v2, and Inception-v3) using this tool:
+
+<figure>
+ <table>
+ <tr>
+ <th>Model</th>
+ <th>Top-1 Accuracy:<br>Floating point</th>
+ <th>Top-1 Accuracy:<br>Fixed point: 8 bit weights and activations</th>
+ </tr>
+ <tr><td>Mobilenet-v1-128-0.25</td><td>0.415</td><td>0.399</td></tr>
+ <tr><td>Mobilenet-v1-128-0.5</td><td>0.563</td><td>0.549</td></tr>
+ <tr><td>Mobilenet-v1-128-0.75</td><td>0.621</td><td>0.598</td></tr>
+ <tr><td>Mobilenet-v1-128-1</td><td>0.652</td><td>0.64</td></tr>
+ <tr><td>Mobilenet-v1-160-0.25</td><td>0.455</td><td>0.435</td></tr>
+ <tr><td>Mobilenet-v1-160-0.5</td><td>0.591</td><td>0.577</td></tr>
+ <tr><td>Mobilenet-v1-160-0.75</td><td>0.653</td><td>0.639</td></tr>
+ <tr><td>Mobilenet-v1-160-1</td><td>0.68</td><td>0.673</td></tr>
+ <tr><td>Mobilenet-v1-192-0.25</td><td>0.477</td><td>0.458</td></tr>
+ <tr><td>Mobilenet-v1-192-0.5</td><td>0.617</td><td>0.604</td></tr>
+ <tr><td>Mobilenet-v1-192-0.75</td><td>0.672</td><td>0.662</td></tr>
+ <tr><td>Mobilenet-v1-192-1</td><td>0.7</td><td>0.69</td></tr>
+ <tr><td>Mobilenet-v1-224-0.25</td><td>0.498</td><td>0.482</td></tr>
+ <tr><td>Mobilenet-v1-224-0.5</td><td>0.633</td><td>0.622</td></tr>
+ <tr><td>Mobilenet-v1-224-0.75</td><td>0.684</td><td>0.679</td></tr>
+ <tr><td>Mobilenet-v1-224-1</td><td>0.709</td><td>0.697</td></tr>
+ <tr><td>Mobilenet-v2-224-1</td><td>0.718</td><td>0.708</td></tr>
+ <tr><td>Inception_v3</td><td>0.78</td><td>0.775</td></tr>
+ </table>
+ <figcaption>
+ <b>Table 1</b>: Top-1 accuracy of floating point and fully quantized CNNs on Imagenet Validation dataset.
+ </figcaption>
+</figure>
+
+Our pre-trained models are available in the
+<a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/models.md#image-classification-quantized-models" class="external">TensorFlow Lite model repository</a>. The code used to generate
+these models <a href="https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1_train.py" class="external">is available</a>.
+
+
+
These rewrites are an active area of research and experimentation, so the
rewrites and quantized training will likely not work across all models, though
we hope to work towards generalizing these techniques.
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index b27117dd48..e6c04bcf55 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -34,10 +34,10 @@ SKIPPED_PREFIXES = (
'ScalarSummary')
# Valid activation ops for quantization end points.
-_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity']
+_ACTIVATION_OP_SUFFIXES = ['Relu6', 'Relu', 'Identity']
# Regular expression for recognizing nodes that are part of batch norm group.
-_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm')
+_BATCHNORM_RE = re.compile(r'^(.*)BatchNorm/batchnorm')
def BatchNormGroups(graph):
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 2b26302f8a..a3ce041cea 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -13,21 +13,26 @@
# limitations under the License.
# ==============================================================================
"""Tests for common utilities in this package."""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
+batch_norm = layers.batch_norm
+conv2d = layers.conv2d
+
class CommonTest(test_util.TensorFlowTestCase):
@@ -87,6 +92,56 @@ class CommonTest(test_util.TensorFlowTestCase):
for i in inputs:
self.assertIn(i, op.inputs)
+ def testBatchNormScope(self):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ g = ops.Graph()
+ with g.as_default():
+ inputs = array_ops.zeros((batch_size, height, width, depth))
+ stride = 1
+ out_depth = 32
+ scope = ''
+ node = conv2d(
+ inputs,
+ out_depth, [2, 2],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=None,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(False),
+ scope=scope)
+
+ node = nn_ops.relu(node, name='Relu6')
+ bn_list = common.BatchNormGroups(g)
+ with open('/tmp/common_test.pbtxt', 'w') as f:
+ f.write(str(g.as_graph_def()))
+
+ # Exactly one batch norm layer with empty scope should be found
+ self.assertEqual(len(bn_list), 1)
+ self.assertEqual(bn_list[0], '')
+
+ def _BatchNormParams(self, fused=False, force_updates=False):
+ params = {
+ 'center': True,
+ 'scale': True,
+ 'decay': 1.0 - 0.003,
+ 'fused': fused
+ }
+ return params
+
+ def _WeightInit(self, stddev):
+ """Returns a truncated normal variable initializer.
+
+ Function is defined purely to shorten the name so that it stops wrapping.
+
+ Args:
+ stddev: Standard deviation of normal variable.
+
+ Returns:
+ An initializer that initializes with a truncated normal variable.
+ """
+ return init_ops.truncated_normal_initializer(stddev=stddev, seed=1234)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index 2971b28f45..e5790a6e13 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -95,8 +95,7 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
_ComputeBatchNormCorrections(
context='',
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=True))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
@@ -296,8 +295,7 @@ def _FindFusedBatchNorms(graph):
batch_to_space_op=batch_to_space_op)
-def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
- fused_batch_norm):
+def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay):
"""Computes batch norm correction params.
Before batch normalization is frozen:
@@ -327,14 +325,14 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
computation.
freeze_batch_norm_delay: Delay in steps at which computation switches
from regular batch norm to frozen mean and variance.
- fused_batch_norm: Bool, true if fused batch norm is used.
+
Returns:
A tuple of correction_scale, correction_recip, correction_offset
"""
g = ops.get_default_graph()
- prefix = '' if not context else context + '/'
+ prefix = '' if not context else context
with g.name_scope(prefix + 'batch_norm_correction'):
recip_sigma_mv = math_ops.rsqrt(
match.moving_variance_tensor + match.batch_epsilon)
@@ -495,8 +493,23 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# Treat consumer ops in bypass modules differently since they have Add
# operations instead of Relu* above.
- add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
- add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
+ # Changes to make sure that the correct scope is selected for the bypass add
+ # The rule here is that if the scope is of the form: str1/str2 for the
+ # batch norm,
+ # the bypass add is at scope str1. If bn is of scope just str1, then the
+ # bypass add is at scope ''.
+ # If there is no batch norm, then there is no bypass add.
+ add_bypass_ctx = ''
+ if bn:
+ try:
+ add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
+ except AttributeError:
+ add_bypass_ctx = ''
+
+ if add_bypass_ctx:
+ add_bypass_ctx = add_bypass_ctx + '/'
+
+ add_bypass = graph.get_operation_by_name(add_bypass_ctx + 'Add')
nodes_modified_count = common.RerouteTensor(
folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
@@ -505,8 +518,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
def _IsValidUnfusedBatchNorm(graph, context):
"""Checks that the output of the unfused batch norm has consumers."""
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
# Ensure that the output tensor of batch norm has consumers, otherwise this
# is a dangling node and not a match.
return bool(add_shift.outputs[0].consumers())
@@ -538,7 +551,8 @@ def _FindMatchingTensor(graph, match_pattern, scope):
if op.name.endswith(match_pattern):
split_name = op.name.split('/')
num_matches = len(set(split_name) & split_context)
- if num_matches > 0:
+
+ if num_matches > 0 or not scope:
match_dict[op.name] = num_matches
# match_dict contains matching op names from graph with values being
# number of matches to scope. We pick the key with the most matches
@@ -597,21 +611,21 @@ def _GetBatchNormParams(graph, context, has_scaling):
# op.name = MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
# will have 2 matches,scope with a different conv layer will have one match.
- op_suffix_mean = '/BatchNorm/moments/Squeeze'
- op_suffix_variance = '/BatchNorm/moments/Squeeze_1'
- op_suffix_epsilon = '/BatchNorm/batchnorm_1/add/y'
- op_suffix_bn_decay_mean = '/BatchNorm/AssignMovingAvg/decay'
- op_suffix_bn_decay_var = '/BatchNorm/AssignMovingAvg_1/decay'
+ op_suffix_mean = 'BatchNorm/moments/Squeeze'
+ op_suffix_variance = 'BatchNorm/moments/Squeeze_1'
+ op_suffix_epsilon = 'BatchNorm/batchnorm_1/add/y'
+ op_suffix_bn_decay_mean = 'BatchNorm/AssignMovingAvg/decay'
+ op_suffix_bn_decay_var = 'BatchNorm/AssignMovingAvg_1/decay'
if variable_scope.get_variable_scope().use_resource:
- op_suffix_gamma = '/BatchNorm/gamma/Read/ReadVariableOp'
+ op_suffix_gamma = 'BatchNorm/gamma/Read/ReadVariableOp'
op_suffix_moving_variance = (
- '/BatchNorm/moving_variance/Read/ReadVariableOp')
- op_suffix_moving_mean = ('/BatchNorm/moving_mean/Read/ReadVariableOp')
+ 'BatchNorm/moving_variance/Read/ReadVariableOp')
+ op_suffix_moving_mean = ('BatchNorm/moving_mean/Read/ReadVariableOp')
else:
- op_suffix_gamma = '/BatchNorm/gamma'
- op_suffix_moving_variance = '/BatchNorm/moving_variance/read'
- op_suffix_moving_mean = '/BatchNorm/moving_mean/read'
+ op_suffix_gamma = 'BatchNorm/gamma'
+ op_suffix_moving_variance = 'BatchNorm/moving_variance/read'
+ op_suffix_moving_mean = 'BatchNorm/moving_mean/read'
# Parse through list of ops to find relevant ops
batch_mean_tensor = _FindMatchingTensor(graph, op_suffix_mean, context)
@@ -679,8 +693,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
the folded graph (add_fold).
"""
mul_scale_name = 'mul_1' if has_scaling else 'mul'
- mul_scale = graph.get_operation_by_name(context +
- '/BatchNorm/batchnorm_1/' +
+ mul_scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
mul_scale_name)
op_below = mul_scale.inputs[0].op
# Skip over the BatchToSpace operation in the case of atrous convolutions.
@@ -697,8 +710,7 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
_ComputeBatchNormCorrections(
context=context,
match=match,
- freeze_batch_norm_delay=freeze_batch_norm_delay,
- fused_batch_norm=False))
+ freeze_batch_norm_delay=freeze_batch_norm_delay))
# Special handling for weights of depthwise convolution.
if op_below.type == 'DepthwiseConv2dNative':
new_shape = [
@@ -706,27 +718,27 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
weights.get_shape().as_list()[3]
]
scale_name = 'mul' if has_scaling else 'Rsqrt'
- scale = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/' + scale_name)
+ scale = graph.get_operation_by_name(context + 'BatchNorm/batchnorm_1/' +
+ scale_name)
scale = array_ops.reshape(scale.outputs[0], new_shape,
- context + '/scale_reshape')
+ context + 'scale_reshape')
if correction_scale is not None:
correction_scale = array_ops.reshape(correction_scale, new_shape,
- context + '/correction_reshape')
+ context + 'correction_reshape')
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
+ context + 'correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights),
- (1, scale)])
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights),
+ (1, scale)])
elif op_below.type in ['Conv2D', 'MatMul']:
if correction_scale is not None:
with ops.device(mul_scale.device):
weights = math_ops.multiply(correction_scale, weights,
- context + '/correction_mult')
- mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)])
+ context + 'correction_mult')
+ mul_fold = _CloneOp(mul_scale, context + 'mul_fold', [(0, weights)])
else:
raise ValueError('Cannot handle operation of type: %s' % op_below.type)
_AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0])
@@ -734,8 +746,8 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold',
[(1, mul_fold.outputs[0])])
- add_shift = graph.get_operation_by_name(
- context + '/BatchNorm/batchnorm_1/add_1')
+ add_shift = graph.get_operation_by_name(context +
+ 'BatchNorm/batchnorm_1/add_1')
corrected_output = conv_or_fc_folded.outputs[0]
# Copy the batch to space operation if we have a atrous convolution.
@@ -748,10 +760,10 @@ def _CreateFoldedOp(graph, context, has_scaling, freeze_batch_norm_delay,
if correction_offset is not None:
with ops.device(conv_or_fc_folded.device):
corrected_output = math_ops.multiply(correction_recip, corrected_output,
- context + '/post_conv_mul')
+ context + 'post_conv_mul')
corrected_output = math_ops.add(corrected_output, (correction_offset),
- context + '/correction_add')
- add_fold = _CloneOp(add_shift, context + '/add_fold', [(0, corrected_output)])
+ context + 'correction_add')
+ add_fold = _CloneOp(add_shift, context + 'add_fold', [(0, corrected_output)])
_AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0])
return add_shift, add_fold
@@ -930,7 +942,7 @@ def _HasScaling(graph, input_to_ops_map, bn):
Returns:
A boolean indicating whether this batch norm layer has scaling enabled.
"""
- rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm_1/Rsqrt')
+ rsqrt_op = graph.get_operation_by_name(bn + 'BatchNorm/batchnorm_1/Rsqrt')
rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op)
return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index e88db0acd5..5e63d33db8 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -97,8 +97,11 @@ def Quantize(graph,
layer_match.activation_op)
add_context = context
if layer_match.bypass_op:
- add_context = re.search(r'^(.*)/([^/]+)', context).group(1)
-
+ pattern_match_result = re.search(r'^(.*)/([^/]+)', context)
+ if pattern_match_result is not None:
+ add_context = pattern_match_result.group(1)
+ else:
+ add_context = ''
# If `scope` is given, only quantize it if the producer of weights
# (usually it's the layer op) is in the right scope.
_InsertQuantOp(
@@ -156,8 +159,12 @@ def Quantize(graph,
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
- post_activation_bypass_context = re.search(
- r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
+ pattern_match_result = re.search(
+ r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name)
+ if pattern_match_result is not None:
+ post_activation_bypass_context = pattern_match_result.group(1)
+ else:
+ post_activation_bypass_context = ''
# If `scope` is given, only quantize it if the producer is in the right
# scope.
# Make sure the op following this isn't an activation. In which case, we
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 31a2955ddb..f6bf57a789 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -58,85 +58,102 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], False)
- test_fn(params[0], params[1], params[2], params[3], True)
+ test_fn(params[0], params[1], params[2], params[3], False, None)
+ test_fn(params[0], params[1], params[2], params[3], True, None)
+ # Test with both empty scope and an example scope
+ test_fn(params[0], params[1], params[2], params[3], False, 'test')
+ test_fn(params[0], params[1], params[2], params[3], True, 'test')
def _AssertCorrectQuantizedGraphWithoutBatchNorm(
self, graph, scope, layer, activation_op_name, with_bypass, delay,
use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
- weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' +
- quantization_node_name)
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+ weights_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
self.assertEqual(weights_quant.type, quantization_node_name)
# Assemble the expected inputs.
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + 'depthwise/ReadVariableOp')
else:
- expected_inputs.append(scope + '/' + layer + '/ReadVariableOp')
+ expected_inputs.append(conv_scope + delim + layer + '/ReadVariableOp')
else:
expected_inputs = [
- scope + '/weights_quant/AssignMinLast',
- scope + '/weights_quant/AssignMaxLast',
+ conv_scope + delim + 'weights_quant/AssignMinLast',
+ conv_scope + delim + 'weights_quant/AssignMaxLast',
]
if layer == 'DepthwiseConv2dNative':
- expected_inputs.append(scope + '/depthwise_weights/read')
+ expected_inputs.append(conv_scope + delim + 'depthwise_weights/read')
else:
- expected_inputs.append(scope + '/weights/read')
+ expected_inputs.append(conv_scope + delim + 'weights/read')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if delay and delay > 0:
- output_op_name = scope + '/weights_quant/delayed_quant/Switch_1'
+ output_op_name = (
+ conv_scope + delim + 'weights_quant/delayed_quant/Switch_1')
else:
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + '/depthwise'
+ output_op_name = conv_scope + delim + 'depthwise'
else:
- output_op_name = scope + '/' + layer
+ output_op_name = conv_scope + delim + layer
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
- conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' +
- quantization_node_name)
+ conv_quant = graph.get_operation_by_name(
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- scope + '/BiasAdd',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim + 'BiasAdd',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
+ conv_scope + delim + 'BiasAdd'
]
self._AssertInputOpsAre(conv_quant, expected_inputs)
- output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
- if delay else 'test/Add')
+
+ output_op_name = (
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name('test/act_quant/' +
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
- 'test/' + activation_op_name,
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + activation_op_name,
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
- 'test/' + activation_op_name
+ scope + 'act_quant/AssignMinEma', scope + 'act_quant/AssignMaxEma',
+ scope + activation_op_name
]
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -145,7 +162,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> Conv2d no batch norm -> Activation.
Args:
@@ -156,6 +174,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -165,7 +184,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -173,16 +194,19 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
quantize.Quantize(graph, True, quant_delay=delay)
+ if conv_scope is None:
+ conv_scope = ''
+
self._AssertCorrectQuantizedGraphWithoutBatchNorm(
graph, scope, 'Conv2D', activation_op_name, with_bypass, delay,
use_resource)
@@ -192,7 +216,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
- with_bypass, delay, use_resource):
+ with_bypass, delay, use_resource, scope):
"""Tests quantization: inputs -> FC no batch norm -> Activation.
Args:
@@ -203,6 +227,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -211,16 +236,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ fc_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
- scope=scope)
+ scope=fc_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -235,7 +262,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ self, activation, activation_op_name, with_bypass, delay, use_resource,
+ scope):
"""Tests quantization: inputs -> DWConv2d no batch norm -> Activation.
Args:
@@ -246,6 +274,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -254,7 +283,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -263,10 +295,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -280,8 +312,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_AtrousConvWithoutBatchNorm)
- def _TestQuantize_AtrousConvWithoutBatchNorm(
- self, activation, activation_op_name, with_bypass, delay, use_resource):
+ def _TestQuantize_AtrousConvWithoutBatchNorm(self, activation,
+ activation_op_name, with_bypass,
+ delay, use_resource, scope):
"""Tests quantization: inputs -> atrous conv no batch norm -> Activation.
Args:
@@ -292,6 +325,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -300,7 +334,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
activation_fn = None if with_bypass else activation
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -309,10 +346,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
- scope=scope)
+ scope=conv_scope)
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
@@ -353,78 +390,96 @@ class QuantizeTest(test_util.TensorFlowTestCase):
]
for params in parameters_list:
# Test everything with resource variables and normal variables.
- test_fn(params[0], params[1], params[2], params[3], params[4], False)
- test_fn(params[0], params[1], params[2], params[3], params[4], True)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], True, None)
+ test_fn(params[0], params[1], params[2], params[3], params[4], False,
+ 'test')
+ test_fn(params[0], params[1], params[2], params[3], params[4], True,
+ 'test')
def _AssertCorrectQuantizedGraphWithBatchNorm(self, graph, scope, layer,
activation_op_name, with_bypass,
delay, use_resource):
quantization_node_name = 'FakeQuantWithMinMaxVars'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ delim = '/' if conv_scope else ''
+
+ if scope:
+ scope = scope + '/'
+
weights_quant = graph.get_operation_by_name(
- scope + '/weights_quant/' + quantization_node_name)
+ conv_scope + delim + 'weights_quant/' + quantization_node_name)
+
self.assertEqual(weights_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'weights_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/weights_quant/' + 'AssignMinLast',
- scope + '/weights_quant/' + 'AssignMaxLast'
+ conv_scope + delim + 'weights_quant/' + 'AssignMinLast',
+ conv_scope + delim + 'weights_quant/' + 'AssignMaxLast'
]
- expected_inputs.append(scope + '/mul_fold')
+ expected_inputs.append(conv_scope + delim + 'mul_fold')
self._AssertInputOpsAre(weights_quant, expected_inputs)
if layer == 'DepthwiseConv2dNative':
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/depthwise_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else 'depthwise_Fold')
else:
- output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1'
- if delay else '/' + layer + '_Fold')
+ output_op_name = conv_scope + delim + (
+ 'weights_quant/delayed_quant/Switch_1' if delay else layer + '_Fold')
self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name])
if with_bypass:
conv_quant = graph.get_operation_by_name(
- scope + '/conv_quant/' + quantization_node_name)
+ conv_scope + delim + 'conv_quant/' + quantization_node_name)
self.assertEqual(conv_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- scope + '/conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ conv_scope + delim +
+ 'conv_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- scope + '/conv_quant/AssignMinEma',
- scope + '/conv_quant/AssignMaxEma',
+ conv_scope + delim + 'conv_quant/AssignMinEma',
+ conv_scope + delim + 'conv_quant/AssignMaxEma',
]
- expected_inputs.append(scope + '/add_fold')
+ expected_inputs.append(conv_scope + delim + 'add_fold')
self._AssertInputOpsAre(conv_quant, expected_inputs)
output_op_name = (
- scope + '/conv_quant/delayed_quant/Switch_1' if delay else 'test/Add')
+ conv_scope + delim + 'conv_quant/delayed_quant/Switch_1'
+ if delay else scope + 'Add')
self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name])
- act_quant = graph.get_operation_by_name(
- 'test/act_quant/' + quantization_node_name)
+ act_quant = graph.get_operation_by_name(scope + 'act_quant/' +
+ quantization_node_name)
self.assertEqual(act_quant.type, quantization_node_name)
if use_resource:
expected_inputs = [
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
- 'test/act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp',
+ scope + 'act_quant/FakeQuantWithMinMaxVars/ReadVariableOp_1',
]
else:
expected_inputs = [
- 'test/act_quant/AssignMinEma',
- 'test/act_quant/AssignMaxEma',
+ scope + 'act_quant/AssignMinEma',
+ scope + 'act_quant/AssignMaxEma',
]
- expected_inputs.append('test/' + activation_op_name)
+ expected_inputs.append(scope + activation_op_name)
self._AssertInputOpsAre(act_quant, expected_inputs)
- output_op_name = ('test/act_quant/delayed_quant/Switch_1'
- if delay else 'control_dependency')
+ output_op_name = (
+ scope + 'act_quant/delayed_quant/Switch_1'
+ if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
self._AssertIdempotent(graph)
@@ -433,7 +488,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@@ -445,6 +500,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -453,7 +509,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = conv2d(
inputs,
out_depth, [5, 5],
@@ -463,13 +521,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -487,7 +545,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, fused_batch_norm,
- use_resource):
+ use_resource, scope):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@@ -499,6 +557,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -506,7 +565,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, depth = 5, 256
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = fully_connected(
inputs,
out_depth,
@@ -514,13 +575,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -540,7 +601,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@@ -552,6 +613,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -559,7 +621,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
node = separable_conv2d(
inputs,
None, [5, 5],
@@ -570,13 +634,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -595,7 +659,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
def _TestQuantize_AtrousConvWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay,
- fused_batch_norm, use_resource):
+ fused_batch_norm, use_resource, scope):
"""Tests quantization: inputs -> atrous conv with batch norm -> Activation.
Args:
@@ -607,6 +671,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_resource: Bool, when true uses resource variables.
+ scope: String, specifies top level scope for the graph
"""
graph = ops.Graph()
with graph.as_default():
@@ -614,7 +679,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
dilation_rate = 2
- scope = 'test/test2' if with_bypass else 'test'
+ conv_scope = self._GetConvScope(scope, with_bypass)
+ scope = '' if scope is None else scope
+ delim = '/' if scope else ''
+
node = separable_conv2d(
inputs,
None, [3, 3],
@@ -625,13 +693,13 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
- scope=scope)
+ scope=conv_scope)
# Manually add a bypass (optional) and an activation.
if with_bypass:
- node = math_ops.add(inputs, node, name='test/Add')
+ node = math_ops.add(inputs, node, name=scope + delim + 'Add')
- node = activation(node, name='test/' + activation_op_name)
+ node = activation(node, name=scope + delim + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
@@ -718,6 +786,18 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with open('/tmp/bn_quant_test.pbtxt', 'w') as f:
f.write(str(graph.as_graph_def()))
+ def _GetConvScope(self, scope, with_bypass):
+ if scope is None:
+ scope = ''
+ delim = '/' if scope else ''
+
+ if with_bypass:
+ conv_scope = scope + delim + 'test2'
+ else:
+ conv_scope = scope
+
+ return conv_scope
+
def _BatchNormParams(self, fused=False, force_updates=False):
params = {
'center': True,
diff --git a/tensorflow/contrib/rate/rate_test.py b/tensorflow/contrib/rate/rate_test.py
index 08908104f4..3dee163881 100644
--- a/tensorflow/contrib/rate/rate_test.py
+++ b/tensorflow/contrib/rate/rate_test.py
@@ -46,7 +46,7 @@ class RateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
a = r_(array_ops.ones([1]), denominator=array_ops.ones([1]))
self.evaluate(variables.global_variables_initializer())
@@ -67,7 +67,7 @@ class RateTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testWhileLoop(self):
- with self.test_session():
+ with self.cached_session():
r_ = rate.Rate()
def body(value, denom, i, ret_rate):
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index c3db71359c..3abf7bd6da 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import copy
from tensorflow.contrib.recurrent.python.ops import recurrent
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -62,7 +61,7 @@ class _FunctionalRnnCell(object):
assert initial_state is not None
# TODO(drpng): Dtype needs to be configurable.
- input_dtypes = [dtypes.float32] + _GetDTypesFromStructure(initial_state)
+ input_dtypes = [seq_inputs.dtype] + _GetDTypesFromStructure(initial_state)
# See _index.
like_inputs_t = nest.map_structure(
lambda x: array_ops.stop_gradient(array_ops.gather(x, 0)), seq_inputs)
@@ -144,7 +143,10 @@ class _FunctionalRnnCell(object):
@property
def extended_initial_state(self):
if self._prepend_output:
- return [array_ops.zeros(self._output_shape), self._state_template]
+ return [array_ops.zeros(
+ self._output_shape,
+ dtype=_GetDTypesFromStructure(self._state_template)[0]),
+ self._state_template]
else:
# The base case, where the output is just the hidden state.
return self._state_template
@@ -185,7 +187,7 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
lengths = array_ops.tile(
array_ops.reshape(sequence_lengths, [-1, 1]), [1, max_time])
is_less = math_ops.cast(
- math_ops.less(output_time, lengths), dtype=dtypes.float32)
+ math_ops.less(output_time, lengths), dtype=tf_output.dtype)
keep_mask = array_ops.tile(
array_ops.expand_dims(is_less, -1),
[1, 1, vector_size])
@@ -217,7 +219,7 @@ def _PickFinalStateFromHistory(acc_state, sequence_length):
def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
- total_time, inputs_lengths):
+ total_time, inputs_lengths, is_reversed):
"""Post-process output of recurrent.
This function takes the accumulated extended state and extracts the requested
@@ -226,6 +228,8 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
When `inputs_lengths` has been set, it extracts the output from the
accumulated state. It also sets outputs past.
+ When `is_reversed` is true, the output will be reversed in this function.
+
It also sets the static shape information.
Args:
@@ -236,11 +240,12 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
func_cell: The functional wrapper around the cell.
total_time: A scalar integer tensor.
inputs_lengths: An integer tensor with one entry per input.
+ is_reversed: A boolean to indicate if the sequence is reversed.
Returns:
A tuple with the outputs at each time, and the final state.
"""
- if inputs_lengths is None:
+ if inputs_lengths is None or is_reversed:
flat_final_state = func_cell.MaybeRemoveOutputFromState(
nest.flatten(extended_final_state))
tf_state = nest.pack_sequence_as(func_cell.state_template, flat_final_state)
@@ -254,21 +259,28 @@ def _PostProcessOutput(extended_acc_state, extended_final_state, func_cell,
tf_state = _PickFinalStateFromHistory(acc_state, inputs_lengths)
output_from_state = func_cell.GetOutputFromState(extended_acc_state)
+ if is_reversed:
+ output_from_state = array_ops.reverse(output_from_state, [0])
tf_output = array_ops.transpose(output_from_state, [1, 0, 2])
tf_output.set_shape(
[func_cell.output_shape[0], total_time, func_cell.output_shape[1]])
if inputs_lengths is not None:
# Need set the outputs to zero.
tf_output = _ApplyLengthsToBatch(inputs_lengths, tf_output)
- # tf_output = array_ops.zeros([4, 3, 5])
_SetShapeFromTemplate(tf_state, func_cell.state_template)
return tf_output, tf_state
# pylint: disable=invalid-name
-def functional_rnn(cell, inputs, sequence_length=None,
- initial_state=None, dtype=None, time_major=False,
- scope=None, use_tpu=False):
+def functional_rnn(cell,
+ inputs,
+ sequence_length=None,
+ initial_state=None,
+ dtype=None,
+ time_major=False,
+ scope=None,
+ use_tpu=False,
+ reverse=False):
"""Same interface as `tf.nn.dynamic_rnn`."""
with variable_scope.variable_scope(scope or 'rnn'):
if not time_major:
@@ -283,33 +295,41 @@ def functional_rnn(cell, inputs, sequence_length=None,
max_length = math_ops.reduce_max(sequence_length)
else:
max_length = None
+ if reverse:
+ inputs = array_ops.reverse(inputs, [0])
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
max_input_length=max_length,
- use_tpu=use_tpu)
+ use_tpu=use_tpu,
+ aligned_end=reverse)
+
tf_output, tf_state = _PostProcessOutput(
- extended_acc_state, extended_final_state, func_cell,
- inputs_flat[0].shape[0], sequence_length)
+ extended_acc_state,
+ extended_final_state,
+ func_cell,
+ inputs_flat[0].shape[0],
+ sequence_length,
+ is_reversed=reverse)
if time_major:
tf_output = array_ops.transpose(tf_output, [1, 0, 2])
return tf_output, tf_state
-def bidirectional_functional_rnn(
- cell_fw,
- cell_bw,
- inputs,
- initial_state_fw=None,
- initial_state_bw=None,
- dtype=None,
- sequence_length=None,
- time_major=False,
- use_tpu=False,
- scope=None):
+def bidirectional_functional_rnn(cell_fw,
+ cell_bw,
+ inputs,
+ initial_state_fw=None,
+ initial_state_bw=None,
+ dtype=None,
+ sequence_length=None,
+ time_major=False,
+ use_tpu=False,
+ fast_reverse=False,
+ scope=None):
"""Creates a bidirectional recurrent neural network.
Performs fully dynamic unrolling of inputs in both directions. Built to be API
@@ -340,6 +360,10 @@ def bidirectional_functional_rnn(
use_tpu: Whether to enable TPU-compatible operation. If True, does not truly
reverse `inputs` in the backwards RNN. Once b/69305369 is fixed, we can
remove this flag.
+ fast_reverse: Whether to use fast tf.reverse to replace tf.reverse_sequence.
+ This is only possible when either all sequence lengths are the same inside
+ the batch, or when the cell function does not change the state on padded
+ input.
scope: An optional scope name for the dynamic RNN.
Returns:
@@ -388,17 +412,29 @@ def bidirectional_functional_rnn(
return array_ops.reverse(input_, axis=[seq_dim])
with variable_scope.variable_scope('bw') as bw_scope:
- inputs_reverse = _reverse(
- inputs, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
- tmp, output_state_bw = functional_rnn(
- cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
- initial_state=initial_state_bw, dtype=dtype,
- time_major=time_major, scope=bw_scope, use_tpu=use_tpu)
-
- output_bw = _reverse(
- tmp, seq_lengths=sequence_length,
- seq_dim=time_dim, batch_dim=batch_dim)
+ if not fast_reverse:
+ inputs = _reverse(
+ inputs,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
+ output_bw, output_state_bw = functional_rnn(
+ cell=cell_bw,
+ inputs=inputs,
+ sequence_length=sequence_length,
+ initial_state=initial_state_bw,
+ dtype=dtype,
+ time_major=time_major,
+ scope=bw_scope,
+ use_tpu=use_tpu,
+ reverse=fast_reverse)
+
+ if not fast_reverse:
+ output_bw = _reverse(
+ output_bw,
+ seq_lengths=sequence_length,
+ seq_dim=time_dim,
+ batch_dim=batch_dim)
outputs = (output_fw, output_bw)
output_states = (output_state_fw, output_state_bw)
diff --git a/tensorflow/contrib/recurrent/python/ops/recurrent.py b/tensorflow/contrib/recurrent/python/ops/recurrent.py
index 4f289e0c85..f51de755d8 100644
--- a/tensorflow/contrib/recurrent/python/ops/recurrent.py
+++ b/tensorflow/contrib/recurrent/python/ops/recurrent.py
@@ -274,8 +274,16 @@ def _ConvertNoneGradientToZeros(xs, dxs):
class _Recurrent(object):
"""A helper class to construct a recurrent neural net."""
- def __init__(self, cell_fn, cell_grad, theta, state0, inputs,
- max_input_length, extras, use_tpu):
+ def __init__(self,
+ cell_fn,
+ cell_grad,
+ theta,
+ state0,
+ inputs,
+ max_input_length,
+ extras,
+ use_tpu,
+ aligned_end=False):
"""RNN helper class.
Args:
@@ -294,6 +302,8 @@ class _Recurrent(object):
and shapes of this `extras`.
use_tpu: A boolean indicating whether the computation is mean to
run on a TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
"""
self._theta = theta
self._state = state0
@@ -303,6 +313,7 @@ class _Recurrent(object):
self._cell_fn = cell_fn
self._cell_grad = cell_grad
self._extras = extras
+ self._aligned_end = aligned_end
# pylint: disable=unbalanced-tuple-unpacking
@@ -417,10 +428,11 @@ class _Recurrent(object):
acc_state = _EmptyAcc(slen_dim, state0)
acc_extras = _EmptyAcc(slen_dim, extras)
- dev_t = array_ops.constant(0, dtype=dev_t_type)
+ t = slen_dim - max_input_length if self._aligned_end else 0
+ dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
run = functional_ops.For(
- start=0,
- limit=max_input_length,
+ start=t,
+ limit=slen_dim if self._aligned_end else max_input_length,
delta=1,
inputs=[dev_t] + _Flatten(
[theta, state0, inputs, acc_state, acc_extras]),
@@ -551,13 +563,16 @@ class _Recurrent(object):
d_theta = _EmptyLike(theta)
d_inputs = _EmptyLike(inputs)
+ slen_dim = _SeqLenDim(inputs)
+
# Loop backwards. Note the loop's limit is open-ended, so goes through
# t=0.
- t = max_input_length - 1
+ t = slen_dim - 1 if self._aligned_end else max_input_length - 1
dev_t = math_ops.to_int32(t) if use_tpu else math_ops.to_int64(t)
+ limit = slen_dim - max_input_length - 1 if self._aligned_end else -1
run = functional_ops.For(
start=t,
- limit=-1,
+ limit=limit,
delta=-1,
inputs=[dev_t] + _Flatten([
theta, state0, inputs, acc_state, acc_extras, d_theta, d_state1,
@@ -641,7 +656,8 @@ def Recurrent(theta,
cell_grad=None,
extras=None,
max_input_length=None,
- use_tpu=False):
+ use_tpu=False,
+ aligned_end=False):
"""Compute a recurrent neural net.
Roughly, Recurrent() computes the following:
@@ -684,6 +700,8 @@ def Recurrent(theta,
truncate the computation if the inputs have been allocated to a
larger size. A scalar tensor.
use_tpu: whether or not we are on TPU.
+ aligned_end: A boolean indicating whether the sequence is aligned at
+ the end.
Returns:
accumulate_state and the final state.
@@ -717,4 +735,5 @@ def Recurrent(theta,
inputs=inputs,
max_input_length=max_input_length,
extras=extras,
- use_tpu=use_tpu).Compute()
+ use_tpu=use_tpu,
+ aligned_end=aligned_end).Compute()
diff --git a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
index 6253f96315..e30e7255fa 100644
--- a/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
+++ b/tensorflow/contrib/resampler/python/ops/resampler_ops_test.py
@@ -210,7 +210,7 @@ class ResamplerTest(test.TestCase):
# Input data shape is not defined over a 2D grid, i.e. its shape is not like
# (batch_size, data_height, data_width, data_channels).
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_depth,
data_channels)
data = np.zeros(data_shape)
@@ -225,7 +225,7 @@ class ResamplerTest(test.TestCase):
sess.run(outputs)
# Warp tensor must be at least a matrix, with shape [batch_size, 2].
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size,)
@@ -238,7 +238,7 @@ class ResamplerTest(test.TestCase):
sess.run(outputs)
# The batch size of the data and warp tensors must be the same.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size+1, warp_height, warp_width, 2)
@@ -252,7 +252,7 @@ class ResamplerTest(test.TestCase):
# The warp tensor must contain 2D coordinates, i.e. its shape last dimension
# must be 2.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data_shape = (batch_size, data_height, data_width, data_channels)
data = np.zeros(data_shape)
warp_shape = (batch_size, warp_height, warp_width, 3)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 1c23c28860..0d615923e0 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -49,7 +49,7 @@ class RpcOpTestBase(object):
return rpc_op.try_rpc(*args, protocol=self._protocol, **kwargs)
def testScalarHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors = self.rpc(
@@ -63,7 +63,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([2, 3, 4], response_message.values)
def testScalarHostPortTryRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = (
test_example_pb2.TestCase(values=[1, 2, 3]).SerializeToString())
response_tensors, status_code, status_message = self.try_rpc(
@@ -83,7 +83,7 @@ class RpcOpTestBase(object):
self.assertEqual(b'', status_message_values)
def testEmptyHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = []
response_tensors = self.rpc(
method=self.get_method_name('Increment'),
@@ -98,7 +98,7 @@ class RpcOpTestBase(object):
'/InvalidService.Increment',
self.get_method_name('InvalidMethodName')
]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(self.invalid_method_string):
sess.run(self.rpc(method=method, address=self._address, request=''))
@@ -111,7 +111,7 @@ class RpcOpTestBase(object):
def testInvalidAddress(self):
# This covers the case of address='' and address='localhost:293874293874'
address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
@@ -128,7 +128,7 @@ class RpcOpTestBase(object):
self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors = self.rpc(
method=self.get_method_name('AlwaysFailWithInvalidArgument'),
address=self._address,
@@ -150,7 +150,7 @@ class RpcOpTestBase(object):
self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
def testSometimesFailingMethodWithManyRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Fail hard by default.
response_tensors = self.rpc(
method=self.get_method_name('SometimesFailWithInvalidArgument'),
@@ -179,7 +179,7 @@ class RpcOpTestBase(object):
self.assertAllEqual(expected_message_values, status_message_values)
def testVecHostPortRpc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -197,7 +197,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortManyParallelRpcs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [
test_example_pb2.TestCase(
values=[i, i + 1, i + 2]).SerializeToString() for i in range(20)
@@ -219,7 +219,7 @@ class RpcOpTestBase(object):
self.assertAllEqual([i + 1, i + 2, i + 3], response_message.values)
def testVecHostPortRpcUsingEncodeAndDecodeProto(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = encode_proto_op.encode_proto(
message_type='tensorflow.contrib.rpc.TestCase',
field_names=['values'],
@@ -241,7 +241,7 @@ class RpcOpTestBase(object):
for i in range(20)], response_shape_values)
def testVecHostPortRpcCancelsUponSessionTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -254,7 +254,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors, options=options)
def testVecHostPortRpcCancelsUponConfiguredTimeOutWhenSleepingForever(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
request_tensors = [''] * 25 # This will launch 25 RPC requests.
response_tensors = self.rpc(
method=self.get_method_name('SleepForever'),
@@ -265,7 +265,7 @@ class RpcOpTestBase(object):
sess.run(response_tensors)
def testTryRpcPropagatesDeadlineErrorWithSometimesTimingOutRequests(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
response_tensors, status_code, status_message = self.try_rpc(
method=self.get_method_name('SometimesSleepForever'),
timeout_in_ms=1000,
@@ -281,7 +281,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
@@ -301,7 +301,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleMethodsSingleRequest(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
methods = flatten(
[[self.get_method_name('Increment'), 'InvalidMethodName']
for _ in range(10)])
@@ -319,7 +319,7 @@ class RpcOpTestBase(object):
def testTryRpcWithMultipleAddressesAndRequests(self):
flatten = lambda x: list(itertools.chain.from_iterable(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
addresses = flatten([[
self._address, 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
] for _ in range(10)])
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index b897224c6d..291ff83791 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -78,23 +78,6 @@ py_test(
],
)
-py_test(
- name = "signature_def_utils_test",
- size = "small",
- srcs = ["python/saved_model/signature_def_utils_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":saved_model_py",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:signature_def_utils",
- "//tensorflow/python/saved_model:utils",
- ],
-)
-
py_library(
name = "keras_saved_model",
srcs = ["python/saved_model/keras_saved_model.py"],
@@ -109,10 +92,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:util",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:export",
- "//tensorflow/python/estimator:keras",
- "//tensorflow/python/estimator:model_fn",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras:engine",
"//tensorflow/python/saved_model",
],
@@ -123,10 +103,12 @@ py_test(
size = "medium",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
+ tags = ["notsan"],
deps = [
":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 074dc655ac..ac95e38011 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -25,13 +25,11 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
-from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
- "get_signature_def_by_key",
"load_keras_model",
"save_keras_model"]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/__init__.py b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
index e3b76bb6f3..fd3dc1d7aa 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/__init__.py
@@ -25,5 +25,4 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 12dd72a95b..060c504523 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -269,7 +269,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
def testSaveAndLoadSavedModelExport(
self, model_builder, uses_learning_phase, optimizer, train_before_export):
saved_model_path = self._save_model_dir()
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
input_arr = np.random.random((1, 3))
target_arr = np.random.random((1, 3))
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
deleted file mode 100644
index f521647999..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SignatureDef utility functions implementation."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-def get_signature_def_by_key(meta_graph_def, signature_def_key):
- """Utility function to get a SignatureDef protocol buffer by its key.
-
- Args:
- meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefMap to
- look up.
- signature_def_key: Key of the SignatureDef protocol buffer to find in the
- SignatureDefMap.
-
- Returns:
- A SignatureDef protocol buffer corresponding to the supplied key, if it
- exists.
-
- Raises:
- ValueError: If no entry corresponding to the supplied key is found in the
- SignatureDefMap of the MetaGraphDef.
- """
- if signature_def_key not in meta_graph_def.signature_def:
- raise ValueError("No SignatureDef with key '%s' found in MetaGraphDef." %
- signature_def_key)
- return meta_graph_def.signature_def[signature_def_key]
diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
deleted file mode 100644
index d2e14f73e4..0000000000
--- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for SignatureDef utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils as signature_def_contrib_utils
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.saved_model import signature_constants
-from tensorflow.python.saved_model import signature_def_utils
-from tensorflow.python.saved_model import utils
-
-
-class SignatureDefUtilsTest(test.TestCase):
-
- def _add_to_signature_def_map(self, meta_graph_def, signature_def_map=None):
- if signature_def_map is not None:
- for key in signature_def_map:
- meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])
-
- def _check_tensor_info(self, tensor_info_map, map_key, expected_tensor_name):
- actual_tensor_info = tensor_info_map[map_key]
- self.assertEqual(expected_tensor_name, actual_tensor_info.name)
-
- def testGetSignatureDefByKey(self):
- x = array_ops.placeholder(dtypes.float32, 1, name="x")
- x_tensor_info = utils.build_tensor_info(x)
-
- y = array_ops.placeholder(dtypes.float32, name="y")
- y_tensor_info = utils.build_tensor_info(y)
-
- foo_signature_def = signature_def_utils.build_signature_def({
- "foo-input": x_tensor_info
- }, {"foo-output": y_tensor_info}, "foo-method-name")
- bar_signature_def = signature_def_utils.build_signature_def({
- "bar-input": x_tensor_info
- }, {"bar-output": y_tensor_info}, "bar-method-name")
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(
- meta_graph_def, {"foo": foo_signature_def,
- "bar": bar_signature_def})
-
- # Look up a key that does not exist in the SignatureDefMap.
- missing_key = "missing-key"
- with self.assertRaisesRegexp(
- ValueError,
- "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
- signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, missing_key)
-
- # Look up the key, `foo` which exists in the SignatureDefMap.
- foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "foo")
- self.assertTrue("foo-method-name", foo_signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(foo_signature_def.inputs))
- self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(foo_signature_def.outputs))
- self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")
-
- # Look up the key, `bar` which exists in the SignatureDefMap.
- bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "bar")
- self.assertTrue("bar-method-name", bar_signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(bar_signature_def.inputs))
- self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(bar_signature_def.outputs))
- self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
-
- def testGetSignatureDefByKeyRegression(self):
- input1 = constant_op.constant("a", name="input-1")
- output1 = constant_op.constant(7.2, name="output-1")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_regression":
- signature_def_utils.regression_signature_def(input1, output1)
- })
-
- # Look up the regression signature with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_regression")
-
- # Check the method name to match the constants regression method name.
- self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs,
- signature_constants.REGRESS_INPUTS, "input-1:0")
-
- # Check outputs in signature def.
- self.assertEqual(1, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs,
- signature_constants.REGRESS_OUTPUTS, "output-1:0")
-
- def testGetSignatureDefByKeyClassification(self):
- input1 = constant_op.constant("a", name="input-1")
- output1 = constant_op.constant("b", name="output-1")
- output2 = constant_op.constant(3.0, name="output-2")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_classification":
- signature_def_utils.classification_signature_def(
- input1, output1, output2)
- })
-
- # Look up the classification signature def with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_classification")
-
- # Check the method name to match the constants classification method name.
- self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs,
- signature_constants.CLASSIFY_INPUTS, "input-1:0")
-
- # Check outputs in signature def.
- self.assertEqual(2, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs,
- signature_constants.CLASSIFY_OUTPUT_CLASSES,
- "output-1:0")
- self._check_tensor_info(signature_def.outputs,
- signature_constants.CLASSIFY_OUTPUT_SCORES,
- "output-2:0")
-
- def testPredictionSignatureDef(self):
- input1 = constant_op.constant("a", name="input-1")
- input2 = constant_op.constant("b", name="input-2")
- output1 = constant_op.constant("c", name="output-1")
- output2 = constant_op.constant("d", name="output-2")
-
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- self._add_to_signature_def_map(meta_graph_def, {
- "my_prediction":
- signature_def_utils.predict_signature_def({
- "input-1": input1,
- "input-2": input2
- }, {"output-1": output1,
- "output-2": output2})
- })
-
- # Look up the prediction signature def with the key used while saving.
- signature_def = signature_def_contrib_utils.get_signature_def_by_key(
- meta_graph_def, "my_prediction")
- self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
- signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(2, len(signature_def.inputs))
- self._check_tensor_info(signature_def.inputs, "input-1", "input-1:0")
- self._check_tensor_info(signature_def.inputs, "input-2", "input-2:0")
-
- # Check outputs in signature def.
- self.assertEqual(2, len(signature_def.outputs))
- self._check_tensor_info(signature_def.outputs, "output-1", "output-1:0")
- self._check_tensor_info(signature_def.outputs, "output-2", "output-2:0")
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index f2c43f30d4..1f3b533de9 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -919,31 +919,28 @@ class AttentionWrapperTest(test.TestCase):
wrapper.BahdanauAttention, wrapper.LuongAttention)
expected_final_output = BasicDecoderOutput(
- rnn_output=ResultSummary(shape=(5, 3, 20),
- dtype=dtype('float32'),
- mean=0.11723966),
- sample_id=ResultSummary(shape=(5, 3),
- dtype=dtype('int32'),
- mean=9.2666666666666675))
+ rnn_output=ResultSummary(
+ shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11723966),
+ sample_id=ResultSummary(
+ shape=(5, 3), dtype=dtype('int32'), mean=7.266666666666667))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
- c=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.003545674),
- h=ResultSummary(shape=(5, 9),
- dtype=dtype('float32'),
- mean=-0.0018327223)),
- attention=ResultSummary(shape=(5, 20),
- dtype=dtype('float32'),
- mean=0.11728073),
+ c=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.003545674),
+ h=ResultSummary(
+ shape=(5, 9), dtype=dtype('float32'), mean=-0.0018327223)),
+ attention=ResultSummary(
+ shape=(5, 20), dtype=dtype('float32'), mean=0.11601614207),
time=3,
- alignments=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
+ alignments=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)),
alignment_history=(),
- attention_state=(
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
- ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
+ attention_state=(ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125),
+ ResultSummary(
+ shape=(5, 8), dtype=dtype('float32'), mean=0.125)))
expected_final_alignment_history = (
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125),
ResultSummary(shape=(3, 5, 8), dtype=dtype('float32'), mean=0.125))
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index f5b6b1bde9..5e28e651c6 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -248,6 +248,7 @@ class TestBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@@ -258,7 +259,8 @@ class TestBeamStep(test.TestCase):
lengths=constant_op.constant(
2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
finished=array_ops.zeros(
- [self.batch_size, self.beam_width], dtype=dtypes.bool))
+ [self.batch_size, self.beam_width], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -281,7 +283,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -313,7 +316,8 @@ class TestBeamStep(test.TestCase):
lengths=ops.convert_to_tensor(
[[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
finished=ops.convert_to_tensor(
- [[False, True, False], [False, False, True]], dtype=dtypes.bool))
+ [[False, True, False], [False, False, True]], dtype=dtypes.bool),
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -336,7 +340,8 @@ class TestBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, state_, log_probs_ = sess.run(
@@ -372,6 +377,7 @@ class TestLargeBeamStep(test.TestCase):
self.vocab_size = 5
self.end_token = 0
self.length_penalty_weight = 0.6
+ self.coverage_penalty_weight = 0.0
def test_step(self):
@@ -411,7 +417,8 @@ class TestLargeBeamStep(test.TestCase):
cell_state=dummy_cell_state,
log_probs=log_probs,
lengths=_lengths,
- finished=_finished)
+ finished=_finished,
+ accumulated_attention_probs=())
logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
0.0001)
@@ -434,7 +441,8 @@ class TestLargeBeamStep(test.TestCase):
batch_size=ops.convert_to_tensor(self.batch_size),
beam_width=self.beam_width,
end_token=self.end_token,
- length_penalty_weight=self.length_penalty_weight)
+ length_penalty_weight=self.length_penalty_weight,
+ coverage_penalty_weight=self.coverage_penalty_weight)
with self.cached_session() as sess:
outputs_, next_state_, _, _ = sess.run(
@@ -476,7 +484,9 @@ class BeamSearchDecoderTest(test.TestCase):
embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
cell = rnn_cell.LSTMCell(cell_depth)
initial_state = cell.zero_state(batch_size, dtypes.float32)
+ coverage_penalty_weight = 0.0
if has_attention:
+ coverage_penalty_weight = 0.2
inputs = array_ops.placeholder_with_default(
np.random.randn(batch_size, decoder_max_time, input_depth).astype(
np.float32),
@@ -508,7 +518,8 @@ class BeamSearchDecoderTest(test.TestCase):
initial_state=cell_state,
beam_width=beam_width,
output_layer=output_layer,
- length_penalty_weight=0.0)
+ length_penalty_weight=0.0,
+ coverage_penalty_weight=coverage_penalty_weight)
final_outputs, final_state, final_sequence_lengths = (
decoder.dynamic_decode(
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 74741a7bd6..605e3143fd 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import numpy as np
+from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
@@ -49,7 +50,8 @@ __all__ = [
class BeamSearchDecoderState(
collections.namedtuple("BeamSearchDecoderState",
- ("cell_state", "log_probs", "finished", "lengths"))):
+ ("cell_state", "log_probs", "finished", "lengths",
+ "accumulated_attention_probs"))):
pass
@@ -260,6 +262,10 @@ class BeamSearchDecoder(decoder.Decoder):
decoder_initial_state = decoder_initial_state.clone(
cell_state=tiled_encoder_final_state)
```
+
+ Meanwhile, with `AttentionWrapper`, coverage penalty is suggested to use
+ when computing scores(https://arxiv.org/pdf/1609.08144.pdf). It encourages
+ the translation to cover all inputs.
"""
def __init__(self,
@@ -271,6 +277,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width,
output_layer=None,
length_penalty_weight=0.0,
+ coverage_penalty_weight=0.0,
reorder_tensor_arrays=True):
"""Initialize the BeamSearchDecoder.
@@ -286,6 +293,8 @@ class BeamSearchDecoder(decoder.Decoder):
`tf.layers.Dense`. Optional layer to apply to the RNN output prior
to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
reorder_tensor_arrays: If `True`, `TensorArray`s' elements within the cell
state will be reordered according to the beam search path. If the
`TensorArray` can be reordered, the stacked form will be returned.
@@ -326,6 +335,7 @@ class BeamSearchDecoder(decoder.Decoder):
self._batch_size = array_ops.size(start_tokens)
self._beam_width = beam_width
self._length_penalty_weight = length_penalty_weight
+ self._coverage_penalty_weight = coverage_penalty_weight
self._initial_cell_state = nest.map_structure(
self._maybe_split_batch_beams, initial_state, self._cell.state_size)
self._start_tokens = array_ops.tile(
@@ -411,13 +421,18 @@ class BeamSearchDecoder(decoder.Decoder):
on_value=ops.convert_to_tensor(0.0, dtype=dtype),
off_value=ops.convert_to_tensor(-np.Inf, dtype=dtype),
dtype=dtype)
+ init_attention_probs = get_attention_probs(
+ self._initial_cell_state, self._coverage_penalty_weight)
+ if init_attention_probs is None:
+ init_attention_probs = ()
initial_state = BeamSearchDecoderState(
cell_state=self._initial_cell_state,
log_probs=log_probs,
finished=finished,
lengths=array_ops.zeros(
- [self._batch_size, self._beam_width], dtype=dtypes.int64))
+ [self._batch_size, self._beam_width], dtype=dtypes.int64),
+ accumulated_attention_probs=init_attention_probs)
return (finished, start_inputs, initial_state)
@@ -631,6 +646,7 @@ class BeamSearchDecoder(decoder.Decoder):
beam_width = self._beam_width
end_token = self._end_token
length_penalty_weight = self._length_penalty_weight
+ coverage_penalty_weight = self._coverage_penalty_weight
with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
cell_state = state.cell_state
@@ -655,7 +671,8 @@ class BeamSearchDecoder(decoder.Decoder):
batch_size=batch_size,
beam_width=beam_width,
end_token=end_token,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight)
finished = beam_search_state.finished
sample_ids = beam_search_output.predicted_ids
@@ -667,7 +684,8 @@ class BeamSearchDecoder(decoder.Decoder):
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
- beam_width, end_token, length_penalty_weight):
+ beam_width, end_token, length_penalty_weight,
+ coverage_penalty_weight):
"""Performs a single step of Beam Search Decoding.
Args:
@@ -684,6 +702,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
beam_width: Python int. The size of the beams.
end_token: The int32 end token.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
Returns:
A new beam state.
@@ -693,6 +713,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
# Calculate the current lengths of the predictions
prediction_lengths = beam_state.lengths
previously_finished = beam_state.finished
+ not_finished = math_ops.logical_not(previously_finished)
# Calculate the total log probs for the new hypotheses
# Final Shape: [batch_size, beam_width, vocab_size]
@@ -708,16 +729,29 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
on_value=np.int64(0),
off_value=np.int64(1),
dtype=dtypes.int64)
- add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
+ add_mask = math_ops.to_int64(not_finished)
lengths_to_add *= array_ops.expand_dims(add_mask, 2)
new_prediction_lengths = (
lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))
+ # Calculate the accumulated attention probabilities if coverage penalty is
+ # enabled.
+ accumulated_attention_probs = None
+ attention_probs = get_attention_probs(
+ next_cell_state, coverage_penalty_weight)
+ if attention_probs is not None:
+ attention_probs *= array_ops.expand_dims(math_ops.to_float(not_finished), 2)
+ accumulated_attention_probs = (
+ beam_state.accumulated_attention_probs + attention_probs)
+
# Calculate the scores for each beam
scores = _get_scores(
log_probs=total_probs,
sequence_lengths=new_prediction_lengths,
- length_penalty_weight=length_penalty_weight)
+ length_penalty_weight=length_penalty_weight,
+ coverage_penalty_weight=coverage_penalty_weight,
+ finished=previously_finished,
+ accumulated_attention_probs=accumulated_attention_probs)
time = ops.convert_to_tensor(time, name="time")
# During the first time step we only consider the initial beam
@@ -775,6 +809,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
range_size=beam_width,
gather_shape=[-1])
next_prediction_len += lengths_to_add
+ next_accumulated_attention_probs = ()
+ if accumulated_attention_probs is not None:
+ next_accumulated_attention_probs = _tensor_gather_helper(
+ gather_indices=next_beam_ids,
+ gather_from=accumulated_attention_probs,
+ batch_size=batch_size,
+ range_size=beam_width,
+ gather_shape=[batch_size * beam_width, -1],
+ name="next_accumulated_attention_probs")
# Pick out the cell_states according to the next_beam_ids. We use a
# different gather_shape here because the cell_state tensors, i.e.
@@ -795,7 +838,8 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
cell_state=next_cell_state,
log_probs=next_beam_probs,
lengths=next_prediction_len,
- finished=next_finished)
+ finished=next_finished,
+ accumulated_attention_probs=next_accumulated_attention_probs)
output = BeamSearchDecoderOutput(
scores=next_beam_scores,
@@ -805,7 +849,53 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
return output, next_state
-def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
+def get_attention_probs(next_cell_state, coverage_penalty_weight):
+ """Get attention probabilities from the cell state.
+
+ Args:
+ next_cell_state: The next state from the cell, e.g. an instance of
+ AttentionWrapperState if the cell is attentional.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+
+ Returns:
+ The attention probabilities with shape `[batch_size, beam_width, max_time]`
+ if coverage penalty is enabled. Otherwise, returns None.
+
+ Raises:
+ ValueError: If no cell is attentional but coverage penalty is enabled.
+ """
+ if coverage_penalty_weight == 0.0:
+ return None
+
+ # Attention probabilities of each attention layer. Each with shape
+ # `[batch_size, beam_width, max_time]`.
+ probs_per_attn_layer = []
+ if isinstance(next_cell_state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer = [attention_probs_from_attn_state(next_cell_state)]
+ elif isinstance(next_cell_state, tuple):
+ for state in next_cell_state:
+ if isinstance(state, attention_wrapper.AttentionWrapperState):
+ probs_per_attn_layer.append(attention_probs_from_attn_state(state))
+
+ if not probs_per_attn_layer:
+ raise ValueError(
+ "coverage_penalty_weight must be 0.0 if no cell is attentional.")
+
+ if len(probs_per_attn_layer) == 1:
+ attention_probs = probs_per_attn_layer[0]
+ else:
+ # Calculate the average attention probabilities from all attention layers.
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in probs_per_attn_layer]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+
+ return attention_probs
+
+
+def _get_scores(log_probs, sequence_lengths, length_penalty_weight,
+ coverage_penalty_weight, finished, accumulated_attention_probs):
"""Calculates scores for beam search hypotheses.
Args:
@@ -813,13 +903,78 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
`[batch_size, beam_width, vocab_size]`.
sequence_lengths: The array of sequence lengths.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
+ coverage_penalty_weight: Float weight to penalize the coverage of source
+ sentence. Disabled with 0.0.
+ finished: A boolean tensor of shape `[batch_size, beam_width]` that
+ specifies which elements in the beam are finished already.
+ accumulated_attention_probs: Accumulated attention probabilities up to the
+ current time step, with shape `[batch_size, beam_width, max_time]` if
+ coverage_penalty_weight is not 0.0.
Returns:
- The scores normalized by the length_penalty.
+ The scores normalized by the length_penalty and coverage_penalty.
+
+ Raises:
+ ValueError: accumulated_attention_probs is None when coverage penalty is
+ enabled.
"""
length_penalty_ = _length_penalty(
sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
- return log_probs / length_penalty_
+ scores = log_probs / length_penalty_
+
+ coverage_penalty_weight = ops.convert_to_tensor(
+ coverage_penalty_weight, name="coverage_penalty_weight")
+ if coverage_penalty_weight.shape.ndims != 0:
+ raise ValueError("coverage_penalty_weight should be a scalar, "
+ "but saw shape: %s" % coverage_penalty_weight.shape)
+
+ if tensor_util.constant_value(coverage_penalty_weight) == 0.0:
+ return scores
+
+ if accumulated_attention_probs is None:
+ raise ValueError(
+ "accumulated_attention_probs can be None only if coverage penalty is "
+ "disabled.")
+
+ # Add source sequence length mask before computing coverage penalty.
+ accumulated_attention_probs = array_ops.where(
+ math_ops.equal(accumulated_attention_probs, 0.0),
+ array_ops.ones_like(accumulated_attention_probs),
+ accumulated_attention_probs)
+
+ # coverage penalty =
+ # sum over `max_time` {log(min(accumulated_attention_probs, 1.0))}
+ coverage_penalty = math_ops.reduce_sum(
+ math_ops.log(math_ops.minimum(accumulated_attention_probs, 1.0)), 2)
+ # Apply coverage penalty to finished predictions.
+ coverage_penalty *= math_ops.to_float(finished)
+ weighted_coverage_penalty = coverage_penalty * coverage_penalty_weight
+ # Reshape from [batch_size, beam_width] to [batch_size, beam_width, 1]
+ weighted_coverage_penalty = array_ops.expand_dims(
+ weighted_coverage_penalty, 2)
+ return scores + weighted_coverage_penalty
+
+
+def attention_probs_from_attn_state(attention_state):
+ """Calculates the average attention probabilities.
+
+ Args:
+ attention_state: An instance of `AttentionWrapperState`.
+
+ Returns:
+ The attention probabilities in the given AttentionWrapperState.
+ If there're multiple attention mechanisms, return the average value from
+ all attention mechanisms.
+ """
+ # Attention probabilities over time steps, with shape
+ # `[batch_size, beam_width, max_time]`.
+ attention_probs = attention_state.alignments
+ if isinstance(attention_probs, tuple):
+ attention_probs = [
+ array_ops.expand_dims(prob, -1) for prob in attention_probs]
+ attention_probs = array_ops.concat(attention_probs, -1)
+ attention_probs = math_ops.reduce_mean(attention_probs, -1)
+ return attention_probs
def _length_penalty(sequence_lengths, penalty_factor):
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.cc b/tensorflow/contrib/session_bundle/bundle_shim.cc
index 4fc36d85ed..c669ced997 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim.cc
@@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir,
const std::unordered_set<string>& saved_model_tags,
- SavedModelBundle* saved_model_bundle) {
+ SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = false;
+ }
if (MaybeSavedModelDirectory(export_dir)) {
LOG(INFO)
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
<< export_dir;
+
return LoadSavedModel(session_options, run_options, export_dir,
saved_model_tags, saved_model_bundle);
} else if (IsPossibleExportDirectory(export_dir)) {
@@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle(
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
"in bundle-shim from: "
<< export_dir;
+ if (is_session_bundle != nullptr) {
+ *is_session_bundle = true;
+ }
return LoadSavedModelFromLegacySessionBundlePath(
session_options, run_options, export_dir, saved_model_bundle);
}
diff --git a/tensorflow/contrib/session_bundle/bundle_shim.h b/tensorflow/contrib/session_bundle/bundle_shim.h
index 4628b6ab1b..7f0f9958d7 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim.h
+++ b/tensorflow/contrib/session_bundle/bundle_shim.h
@@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle(
} // namespace internal
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
-// path.
+// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
+// SavedModel was up-converted and loaded from a SessionBundle.
+// `is_session_bundle` value should not be used if error is returned.
Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir, const std::unordered_set<string>& tags,
- SavedModelBundle* bundle);
+ SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
} // namespace serving
} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
index 9a1dd9303f..815beb73a0 100644
--- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc
+++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc
@@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
void LoadAndValidateSavedModelBundle(const string& export_dir,
const std::unordered_set<string>& tags,
- const string& signature_def_key) {
+ const string& signature_def_key,
+ bool expect_session_bundle) {
SessionOptions session_options;
RunOptions run_options;
SavedModelBundle saved_model_bundle;
+ bool is_session_bundle = false;
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
- session_options, run_options, export_dir, tags, &saved_model_bundle));
+ session_options, run_options, export_dir, tags, &saved_model_bundle,
+ &is_session_bundle));
+ EXPECT_EQ(expect_session_bundle, is_session_bundle);
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
@@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
- kDefaultServingSignatureDefKey);
+ kDefaultServingSignatureDefKey,
+ /*expect_session_bundle=*/true);
// Verify that the named signature is also present.
SessionOptions session_options;
@@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) {
const string saved_model_bundle_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
- {kSavedModelTagServe}, "regress_x_to_y");
+ {kSavedModelTagServe}, "regress_x_to_y",
+ /*expect_session_bundle=*/false);
}
// Checks a basic load fails with an invalid export path.
diff --git a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
index 1bb6fbc570..795de6a408 100644
--- a/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/dataset_data_provider_test.py
@@ -88,7 +88,7 @@ class DatasetDataProviderTest(test.TestCase):
height = 300
width = 280
- with self.test_session():
+ with self.cached_session():
test_dataset = _create_tfrecord_dataset(dataset_dir)
provider = dataset_data_provider.DatasetDataProvider(test_dataset)
key, image, label = provider.get(['record_key', 'image', 'label'])
@@ -111,7 +111,7 @@ class DatasetDataProviderTest(test.TestCase):
height = 300
width = 280
- with self.test_session():
+ with self.cached_session():
provider = dataset_data_provider.DatasetDataProvider(
_create_tfrecord_dataset(dataset_dir))
[image] = provider.get(['image'])
@@ -128,7 +128,7 @@ class DatasetDataProviderTest(test.TestCase):
dataset_dir = tempfile.mkdtemp(prefix=os.path.join(self.get_temp_dir(),
'tfrecord_dataset'))
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
dataset_data_provider.DatasetDataProvider(
_create_tfrecord_dataset(dataset_dir), record_key='image')
diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
index ea8cc0ff61..c457d44e07 100644
--- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py
@@ -39,7 +39,7 @@ class ParallelReaderTest(test.TestCase):
ops.reset_default_graph()
def _verify_all_data_sources_read(self, shared_queue):
- with self.test_session():
+ with self.cached_session():
tfrecord_paths = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=3)
@@ -76,7 +76,7 @@ class ParallelReaderTest(test.TestCase):
self.assertEquals(count0 + count1 + count2, num_reads)
def _verify_read_up_to_out(self, shared_queue):
- with self.test_session():
+ with self.cached_session():
num_files = 3
num_records_per_file = 7
tfrecord_paths = test_utils.create_tfrecord_files(
@@ -161,7 +161,7 @@ class ParallelReadTest(test.TestCase):
ops.reset_default_graph()
def testTFRecordReader(self):
- with self.test_session():
+ with self.cached_session():
self._tfrecord_paths = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=3)
@@ -188,7 +188,7 @@ class SinglePassReadTest(test.TestCase):
ops.reset_default_graph()
def testOutOfRangeError(self):
- with self.test_session():
+ with self.cached_session():
[tfrecord_path] = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=1)
@@ -196,7 +196,7 @@ class SinglePassReadTest(test.TestCase):
tfrecord_path, reader_class=io_ops.TFRecordReader)
init_op = variables.local_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with queues.QueueRunners(sess):
num_reads = 11
@@ -205,7 +205,7 @@ class SinglePassReadTest(test.TestCase):
sess.run([key, value])
def testTFRecordReader(self):
- with self.test_session():
+ with self.cached_session():
[tfrecord_path] = test_utils.create_tfrecord_files(
self.get_temp_dir(), num_files=1)
@@ -213,7 +213,7 @@ class SinglePassReadTest(test.TestCase):
tfrecord_path, reader_class=io_ops.TFRecordReader)
init_op = variables.local_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with queues.QueueRunners(sess):
flowers = 0
diff --git a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
index 6c3e57c47d..7caa42dcb9 100644
--- a/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/prefetch_queue_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.training import queue_runner_impl
class PrefetchQueueTest(test.TestCase):
def testOneThread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 5
@@ -74,7 +74,7 @@ class PrefetchQueueTest(test.TestCase):
thread.join()
def testMultiThread(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 5
@@ -114,7 +114,7 @@ class PrefetchQueueTest(test.TestCase):
thread.join()
def testMultipleDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = 10
image_size = 32
num_batches = 4
@@ -162,7 +162,7 @@ class PrefetchQueueTest(test.TestCase):
prefetch_queue.prefetch_queue([variable_tensor])
def testDynamicPad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create 3 tensors of variable but compatible shapes.
var_shape = [None, 2]
p1 = constant_op.constant([[1, 2], [3, 4]])
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 826242c9d7..3114949b82 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -45,7 +45,7 @@ class TFExampleDecoderTest(test.TestCase):
int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
- with self.test_session():
+ with self.cached_session():
encoded = tf_encoded.eval()
def BytesList(value):
@@ -133,7 +133,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_image = self.DecodeExample(serialized_example, item_handler,
image_format)
- with self.test_session():
+ with self.cached_session():
decoded_image = tf_image.eval()
# We need to recast them here to avoid some issues with uint8.
@@ -265,7 +265,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
@@ -296,7 +296,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
@@ -319,7 +319,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
@@ -342,7 +342,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -366,7 +366,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
@@ -390,7 +390,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -423,7 +423,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -468,7 +468,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'image': parsing_ops.VarLenFeature(dtype=dtypes.float32),
@@ -505,7 +505,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -536,7 +536,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -567,7 +567,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -598,7 +598,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
@@ -625,7 +625,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -657,7 +657,7 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -692,7 +692,7 @@ class TFExampleDecoderTest(test.TestCase):
image, serialized_example = self.GenerateImage(
image_format=image_encoding, image_shape=image_shape)
- with self.test_session():
+ with self.cached_session():
def ConditionalDecoding(keys_to_tensors):
"""See base class."""
@@ -759,7 +759,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -800,7 +800,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
@@ -837,7 +837,7 @@ class TFExampleDecoderTest(test.TestCase):
image, _ = self.GenerateImage(
image_format=image_format, image_shape=image_shape)
tf_encoded = self._Encoder(image, image_format)
- with self.test_session():
+ with self.cached_session():
tf_string = tf_encoded.eval()
example = example_pb2.Example(
@@ -852,7 +852,7 @@ class TFExampleDecoderTest(test.TestCase):
}))
serialized_example = example.SerializeToString()
- with self.test_session():
+ with self.cached_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
decoder = tfexample_decoder.TFExampleDecoder(
@@ -885,7 +885,7 @@ class TFExampleDecoderTest(test.TestCase):
table = lookup_ops.index_table_from_tensor(
constant_op.constant(['dog', 'guinea pig', 'cat']))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(lookup_ops.tables_initializer())
serialized_example = array_ops.reshape(serialized_example, shape=[])
@@ -943,7 +943,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
obtained_class_ids_each_example = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(lookup_ops.tables_initializer())
for example in [example1, example2, example3]:
serialized_example = array_ops.reshape(
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
index 4707dc2229..8fcd7aeef6 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
@@ -47,7 +47,7 @@ def _get_lanczos_tests(dtype_, use_static_shape_, shape_, orthogonalize_,
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
tol = 1e-12 if dtype_ == np.float64 else 1e-5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
else:
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
index a73642716b..2a9100903a 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/least_squares_test.py
@@ -47,7 +47,7 @@ def _get_least_squares_tests(dtype_, use_static_shape_, shape_):
low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
tol = 1e-12 if dtype_ == np.float64 else 1e-6
max_iter = 20
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
index a1282847be..a0e6eb87bc 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
@@ -54,7 +54,7 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_):
x_np = np.zeros_like(rhs_np)
tol = 1e-6 if dtype_ == np.float64 else 1e-3
max_iter = 20
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 5d7534657b..57b4996689 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -33,7 +33,7 @@ class UtilTest(test.TestCase):
a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
x_np = np.array([[2.], [-3.]], dtype=dtype)
y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np, dtype=dtype)
x = constant_op.constant(x_np, dtype=dtype)
@@ -68,7 +68,7 @@ class UtilTest(test.TestCase):
a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
x_np = np.array([[2.], [-3.]], dtype=dtype)
y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np, dtype=dtype)
x = constant_op.constant(x_np, dtype=dtype)
@@ -101,7 +101,7 @@ class UtilTest(test.TestCase):
self._testIdentityOperator(False)
def testL2Norm(self):
- with self.test_session():
+ with self.cached_session():
x_np = np.array([[2], [-3.], [5.]])
x_norm_np = np.linalg.norm(x_np)
x_normalized_np = x_np / x_norm_np
diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
index e4db5f2e3c..e6a0b30567 100644
--- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
+++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py
@@ -38,7 +38,7 @@ class StatSummarizerTest(test.TestCase):
graph_def = graph.as_graph_def()
ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
for _ in range(20):
diff --git a/tensorflow/contrib/summary/summary_ops_graph_test.py b/tensorflow/contrib/summary/summary_ops_graph_test.py
index ae8336daaf..807741e05f 100644
--- a/tensorflow/contrib/summary/summary_ops_graph_test.py
+++ b/tensorflow/contrib/summary/summary_ops_graph_test.py
@@ -52,7 +52,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
summary_ops.histogram('histogram', [1.0], step=1)
summary_ops.image('image', [[[[1.0]]]], step=1)
summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
# The working condition of the ops is tested in the C++ test so we just
@@ -64,7 +64,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -77,7 +77,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer.as_default(), summary_ops.always_record_summaries():
with ops.name_scope('scope'):
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
sess.run(summary_ops.all_summary_ops())
events = summary_test_util.events_from_logdir(logdir)
@@ -90,7 +90,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
writer = summary_ops.create_file_writer(logdir, max_queue=0)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(summary_ops.summary_writer_initializer_op())
step, _ = sess.run(
@@ -105,7 +105,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=1, flush_millis=999999)
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -123,7 +123,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer.as_default(), summary_ops.always_record_summaries():
summary_ops.scalar('scalar', 2.0, step=1)
flush_op = summary_ops.flush()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
# Note: First tf.Event is always file_version.
@@ -157,7 +157,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
with writer3.as_default():
summary_ops.scalar('three', 3.0, step=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run init ops across writers sequentially to avoid race condition.
# TODO(nickfelt): fix race condition in resource manager lookup or create
sess.run(writer1.init())
@@ -191,7 +191,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -219,7 +219,7 @@ class GraphFileTest(test_util.TensorFlowTestCase):
logdir, max_queue=100, flush_millis=1000000)
with writer.as_default():
summary_ops.scalar('one', 1.0, step=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(summary_ops.summary_writer_initializer_op())
get_total = lambda: len(summary_test_util.events_from_logdir(logdir))
self.assertEqual(1, get_total()) # file_version Event
@@ -241,7 +241,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
training_util.get_or_create_global_step()
name = 'hi'
graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
- with self.test_session():
+ with self.cached_session():
with self.create_db_writer().as_default():
summary_ops.initialize(graph=graph)
six.assertCountEqual(self, [name],
@@ -249,7 +249,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
def testScalarSummary(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -280,7 +280,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
def testScalarSummaryNameScope(self):
"""Test record_summaries_every_n_global_steps and all_summaries()."""
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
global_step = training_util.get_or_create_global_step()
global_step.initializer.run()
with ops.device('/cpu:0'):
@@ -311,7 +311,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
self.assertEqual(events[1].summary.value[0].tag, 'scope/my_scalar')
def testSummaryGraphModeCond(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
@@ -332,7 +332,7 @@ class GraphDbTest(summary_test_util.SummaryDbTest):
self.assertEqual(events[1].summary.value[0].tag, 'cond/scalar')
def testSummaryGraphModeWhile(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
training_util.get_or_create_global_step()
logdir = tempfile.mkdtemp()
with summary_ops.create_file_writer(
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 00c855daa3..398ac314f4 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -518,7 +518,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
- "//tensorflow/contrib/estimator:head",
+ "//tensorflow/contrib/estimator:estimator_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
index aa30919167..d49928e3f1 100644
--- a/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
+++ b/tensorflow/contrib/tensor_forest/client/eval_metrics_test.py
@@ -32,7 +32,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
[0.9, 0.8, 0.2], [0.6, 0.4, 0.8]])
targets = constant_op.constant([[0], [2], [1], [1]])
in_top_2_op, update_op = top_2_fn(probabilities, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the in_top_2_op internal operations because
@@ -49,7 +49,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
[0.3, 0.6, 0.9, 0.4, 0.8, 0.6]])
targets = constant_op.constant([3, 0, 2, 5, 1])
in_top_3_op, update_op = top_3_fn(probabilities, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the in_top_3_op internal operations because
@@ -61,7 +61,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
predictions = constant_op.constant([0, 1, 3, 6, 5, 2, 7, 6, 4, 9])
targets = constant_op.constant([0, 1, 4, 6, 5, 1, 7, 5, 4, 8])
accuracy_op, update_op = eval_metrics._accuracy(predictions, targets)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
# need to call in order to run the accuracy_op internal operations because
# it is a streaming function
@@ -74,7 +74,7 @@ class EvalMetricsTest(test_util.TensorFlowTestCase):
targets = constant_op.constant(
[1.0, 4.3, 2.6, 0.5, 1.1, 0.7, 5.1, 3.4, 1.8])
r2_op, update_op = eval_metrics._r2(scores, targets)
- with self.test_session():
+ with self.cached_session():
# initializes internal accuracy vars
variables.local_variables_initializer().run()
# need to call in order to run the r2_op internal operations because
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index f80a34ece6..fe2c91c104 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -246,7 +246,8 @@ class ProcessInputOp : public OpKernel {
const Tensor& input_weights = context->input(7);
const Tensor& leaf_ids_tensor = context->input(8);
- std::unique_ptr<TensorDataSet> data_set(new TensorDataSet(input_spec_, 0));
+ std::unique_ptr<TensorDataSet> data_set(
+ new TensorDataSet(input_spec_, random_seed_));
data_set->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
diff --git a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
index cefcc96051..dd5d028314 100644
--- a/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/tree_utils.cc
@@ -67,11 +67,11 @@ float ClassificationSplitScore(
const Eigen::Tensor<float, 1, Eigen::RowMajor>& splits,
const Eigen::Tensor<float, 1, Eigen::RowMajor>& rights, int32 num_classes,
int i) {
- Eigen::array<int, 1> offsets;
+ Eigen::array<Eigen::Index, 1> offsets;
// Class counts are stored with the total in [0], so the length of each
// count vector is num_classes + 1.
offsets[0] = i * (num_classes + 1) + 1;
- Eigen::array<int, 1> extents;
+ Eigen::array<Eigen::Index, 1> extents;
extents[0] = num_classes;
return WeightedGiniImpurity(splits.slice(offsets, extents)) +
WeightedGiniImpurity(rights.slice(offsets, extents));
@@ -97,7 +97,7 @@ void GetTwoBestClassification(const Tensor& total_counts,
// arguments to ClassificationSplitScore.
const Eigen::Tensor<float, 1, Eigen::RowMajor> splits =
split_counts.Slice(accumulator, accumulator + 1).unaligned_flat<float>();
- Eigen::array<int, 1> bcast;
+ Eigen::array<Eigen::Index, 1> bcast;
bcast[0] = num_splits;
const Eigen::Tensor<float, 1, Eigen::RowMajor> rights =
tc.broadcast(bcast) - splits;
@@ -130,8 +130,8 @@ float RegressionSplitScore(
const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_sums,
const Eigen::Tensor<float, 1, Eigen::RowMajor>& right_squares,
int32 accumulator, int32 num_regression_dims, int i) {
- Eigen::array<int, 1> offsets = {i * num_regression_dims + 1};
- Eigen::array<int, 1> extents = {num_regression_dims - 1};
+ Eigen::array<Eigen::Index, 1> offsets = {i * num_regression_dims + 1};
+ Eigen::array<Eigen::Index, 1> extents = {num_regression_dims - 1};
float left_count = splits_count_accessor(accumulator, i, 0);
float right_count = totals_count_accessor(accumulator, 0) - left_count;
@@ -178,7 +178,7 @@ void GetTwoBestRegression(const Tensor& total_sums, const Tensor& total_squares,
const auto splits_count_accessor = split_sums.tensor<float, 3>();
const auto totals_count_accessor = total_sums.tensor<float, 2>();
- Eigen::array<int, 1> bcast;
+ Eigen::array<Eigen::Index, 1> bcast;
bcast[0] = num_splits;
const auto right_sums = tc_sum.broadcast(bcast) - splits_sum;
const auto right_squares = tc_square.broadcast(bcast) - splits_square;
diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
index e429d12e96..1c4e18dbda 100644
--- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
+++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py
@@ -32,7 +32,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[1], [10]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(
@@ -45,7 +45,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100., 200.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]],
@@ -57,7 +57,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = []
updates = []
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual(init_val, input_data.eval())
@@ -67,7 +67,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
input_data = variables.Variable(init_val)
indices = [[0, 0, 1], [1, 1, 2]]
updates = [100.]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
with self.assertRaisesOpError(
'Number of updates should be same as number of indices.'):
@@ -80,7 +80,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase):
indices = [[0, 0], [1, 1]]
updates = [[100., 200., 300.], [400., 500., 600.]]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run()
self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]],
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 1c9c81827e..e0f0c0d4ff 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -149,7 +149,7 @@ class TensorForestTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
self.assertTrue(isinstance(var, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
resources.initialize_resources(resources.shared_resources()).run()
self.assertEquals(probs.eval().shape, (4, 2))
diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD
index 2b6a2b2f3c..7f0b3255ed 100644
--- a/tensorflow/contrib/tensorboard/BUILD
+++ b/tensorflow/contrib/tensorboard/BUILD
@@ -32,7 +32,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":projector",
- ":trace",
],
)
@@ -60,33 +59,3 @@ py_test(
"//tensorflow/python:summary",
],
)
-
-# API methods and protos in `tf.contrib.tensorboard.plugins.trace` package.
-py_library(
- name = "trace",
- srcs = glob(
- ["plugins/trace/**/*.py"],
- exclude = ["**/*test*"],
- ),
- srcs_version = "PY2AND3",
- deps = [
- ":protos_all_py",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:lib",
- "//tensorflow/python:platform",
- ],
-)
-
-py_test(
- name = "trace_test",
- size = "small",
- srcs = ["plugins/trace/trace_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":trace",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:platform",
- ],
-)
diff --git a/tensorflow/contrib/tensorboard/db/loader.cc b/tensorflow/contrib/tensorboard/db/loader.cc
index 4d7337a53d..6439328022 100644
--- a/tensorflow/contrib/tensorboard/db/loader.cc
+++ b/tensorflow/contrib/tensorboard/db/loader.cc
@@ -111,10 +111,10 @@ int main(int argc, char* argv[]) {
++records;
}
uint64 elapsed = env->NowMicros() - start;
+ uint64 bps = (elapsed == 0 ? offset : static_cast<uint64>(
+ offset / (elapsed / 1000000.0)));
LOG(INFO) << "Loaded " << AddCommas(offset) << " bytes with "
- << AddCommas(records) << " records at "
- << AddCommas(offset / (elapsed / 1000000)) << " bps";
-
+ << AddCommas(records) << " records at " << AddCommas(bps) << " bps";
return 0;
}
diff --git a/tensorflow/contrib/tensorboard/plugins/__init__.py b/tensorflow/contrib/tensorboard/plugins/__init__.py
index 41aa77910c..4ba469eb52 100644
--- a/tensorflow/contrib/tensorboard/plugins/__init__.py
+++ b/tensorflow/contrib/tensorboard/plugins/__init__.py
@@ -20,4 +20,4 @@ from __future__ import print_function
# Add projects here, they will show up under tf.contrib.tensorboard.plugins
from tensorflow.contrib.tensorboard.plugins import projector
-from tensorflow.contrib.tensorboard.plugins import trace
+
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace.py b/tensorflow/contrib/tensorboard/plugins/trace/trace.py
deleted file mode 100644
index 07e5316b8b..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Stores debugging information regarding TensorFlow model."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import parser
-import re
-import token
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import TraceInfo
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import gfile
-
-# List of regex patterns that match files in the core tensorflow library.
-TF_LIB_REGEX_FPATHS = [os.sep + os.path.join('tensorflow', 'python')]
-
-LEFT_TOKENS = [token.LPAR, token.LSQB, token.LBRACE]
-RIGHT_TOKENS = [token.RPAR, token.RSQB, token.RBRACE]
-TOKENS = LEFT_TOKENS + RIGHT_TOKENS
-
-
-def store_trace_info(output_file_path,
- graph=None,
- ignore_regex_fpaths=None):
- """Collects and stores trace information for a TensorFlow model.
-
- The output proto is stored in json format.
-
- Args:
- output_file_path: The path where to store the output proto.
- graph: Optional. The data flow graph. Defaults to `tf.get_default_graph()`.
- ignore_regex_fpaths: Optional. Files whose path matches any of the regexes
- in this list will be ignored. Defaults to patterns that match the core
- tensorflow python library.
- """
- graph = graph or ops.get_default_graph()
-
- if not ignore_regex_fpaths:
- ignore_regex_fpaths = TF_LIB_REGEX_FPATHS
-
- trace_info = TraceInfo()
- # Extract trace information for every op in the graph.
- source_fpaths = set()
- for op in graph.get_operations():
- op_info = trace_info.ops.add()
- op_info.name = op.name
- op_info.op_type = op.type
- op_info.device = op.device
- for trace in op.traceback:
- fname, lineno, _, _ = trace
- # Ignore traces in specified file paths.
- if os.path.isabs(fname) and not _ignore_file_path(fname,
- ignore_regex_fpaths):
- line_trace = op_info.traceback.add()
- line_trace.file_path = fname
- line_trace.line_number = lineno
- source_fpaths.add(fname)
- _add_data_from_tensors(op.inputs, op_info.inputs)
- _add_data_from_tensors(op.outputs, op_info.outputs)
-
- # Read the source files involved in the graph construction.
- for fpath in source_fpaths:
- file_info = trace_info.files.add()
-
- with gfile.Open(fpath, 'r') as f:
- source = f.read()
-
- file_info.file_path = fpath
- file_info.source_code = source
-
- line2start = find_multiline_statements(source)
-
- for key, value in line2start.items():
- file_info.multiline_statements[key] = value
-
- # Make sure the directory for the output file exists.
- output_file_path = os.path.expanduser(output_file_path)
- output_dir = os.path.dirname(output_file_path)
- if not gfile.Exists(output_dir):
- gfile.MakeDirs(output_dir)
-
- # Store the debug information.
- with gfile.Open(output_file_path, 'w') as f:
- f.write(json_format.MessageToJson(trace_info))
-
-
-def find_multiline_statements(source):
- """Parses the python source and finds multiline statements.
-
- Based on counting the number of open and closed parenthesis on each line.
-
- Args:
- source: The source code string.
-
- Returns:
- A dict that maps a line index A to a line index B, where A is the end of a
- multiline statement and B is the start. Line indexing is 0-based.
- """
- # Get the AST.
- tree = parser.suite(source)
- line2paren_count = [0] * (source.count('\n') + 1)
- _count_brackets_braces_parenthesis(tree.totuple(True), line2paren_count)
-
- line2start = {}
- for end in range(len(line2paren_count)):
- if line2paren_count[end] >= 0:
- # This is not the end of a multiline statement.
- continue
- cumulative_paren_count = 0
- for start in range(end, -1, -1):
- cumulative_paren_count += line2paren_count[start]
- if cumulative_paren_count == 0:
- line2start[end] = start
- break
- return line2start
-
-
-def _add_data_from_tensors(tensors, info):
- for t in tensors:
- tensor_info = info.add()
-
- shape = t.get_shape()
- if shape.ndims:
- shape = [(-1 if s is None else s) for s in shape.as_list()]
- tensor_info.shape.extend(shape)
- tensor_info.dtype = t.dtype.name
- tensor_info.num_bytes_per_elem = t.dtype.size
-
- for c in t.consumers():
- tensor_info.consumers.append(c.name)
-
-
-def _ignore_file_path(fname, ignore_regex_fpaths):
- for regex_pattern in ignore_regex_fpaths:
- if re.search(regex_pattern, fname):
- return True
- return False
-
-
-def _count_brackets_braces_parenthesis(node, line2par):
- if isinstance(node[1], tuple):
- for child in node[1:]:
- _count_brackets_braces_parenthesis(child, line2par)
- else:
- tok = node[0]
- if tok in TOKENS:
- lineno = node[2]
- line2par[lineno - 1] += (1 if tok in LEFT_TOKENS else -1)
- return line2par
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto b/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
deleted file mode 100644
index 9f20becb0f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_info.proto
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-syntax = "proto3";
-
-package tensorflow.contrib.tensorboard;
-
-message TraceInfo {
- repeated OpInfo ops = 1;
- repeated FileInfo files = 2;
-}
-
-message OpInfo {
- string name = 1;
- string op_type = 2;
- string device = 3;
- repeated LineTrace traceback = 4;
- repeated TensorInfo inputs = 5;
- repeated TensorInfo outputs = 6;
-}
-
-message LineTrace {
- // Absolute file path.
- string file_path = 1;
- // 1-based line number.
- uint32 line_number = 2;
-}
-
-message TensorInfo {
- // Size of the tensor for each dimension. Value of -1 denotes "unknown"
- // size for that dimension.
- repeated int32 shape = 1;
- // The data type of the tensor.
- string dtype = 2;
- // Number of bytes per element in the tensor.
- uint32 num_bytes_per_elem = 3;
- // List of operation names that consume this tensor.
- repeated string consumers = 4;
-}
-
-message FileInfo {
- // Absolute file path to the source code.
- string file_path = 1;
- string source_code = 2;
- // Map from end of statement to start of statement. End and start are 0-based
- // line indexes.
- map<uint32, uint32> multiline_statements = 3;
-}
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py b/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
deleted file mode 100644
index d580f04c5f..0000000000
--- a/tensorflow/contrib/tensorboard/plugins/trace/trace_test.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.contrib.tensorboard.plugins.trace package."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tempfile
-
-from google.protobuf import json_format
-
-from tensorflow.contrib.tensorboard.plugins import trace
-from tensorflow.python.framework import constant_op
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-
-
-class TraceTest(test.TestCase):
-
- def setUp(self):
- self._temp_dir = tempfile.mkdtemp()
- self._temp_trace_json = self._temp_dir + 'trace.json'
-
- def tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def testEmptyGraph(self):
- trace_info = self._store_and_read_trace_info()
- self.assertEqual(len(trace_info.ops), 0)
-
- def testHasSourceCodeOfThisFile(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.files)
- for file_info in trace_info.files:
- if file_info.file_path.endswith('trace_test.py'):
- return
- self.fail('trace_test file not found in the trace info json')
-
- def testHasTheConstantOp(self):
- constant_op.constant(0)
- trace_info = self._store_and_read_trace_info()
-
- self.assertTrue(trace_info.ops)
-
- for op in trace_info.ops:
- if op.op_type == 'Const':
- return
- self.fail('Could not find operation of type `Const` in the graph')
-
- def testMultilineStatements(self):
- source = """def test():
- a(4,
- 3,
- 1)
-
- b(3, 4, 5)
-
- c((4, 3),
- (),
- )
- """
- line2start = trace.find_multiline_statements(source)
-
- self.assertEqual(line2start[3], 1)
- self.assertEqual(line2start[9], 7)
- self.assertEqual(len(line2start), 2)
-
- def _store_and_read_trace_info(self):
- trace.store_trace_info(self._temp_trace_json)
- trace_info = trace.TraceInfo()
-
- with gfile.Open(self._temp_trace_json) as f:
- text = f.read()
- json_format.Parse(text, trace_info)
-
- return trace_info
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 122a67a407..9e8979bce4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -19,6 +19,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -181,7 +182,12 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":wrap_conversion",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python:session",
"//tensorflow/python:tf_optimizer",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -410,6 +416,31 @@ py_library(
],
)
+cuda_py_test(
+ name = "trt_convert_test",
+ srcs = ["python/trt_convert_test.py"],
+ additional_deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow/python/saved_model:utils",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ "//tensorflow/python/tools:saved_model_utils",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+)
+
cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/contrib/tensorrt/README.md
index 687dee07e1..caf8b6db0d 100644
--- a/tensorflow/contrib/tensorrt/README.md
+++ b/tensorflow/contrib/tensorrt/README.md
@@ -26,4 +26,4 @@ available. An example use can be found in test/test_tftrt.py script
In order to make use of TensorRT integration, you will need a local installation
of TensorRT 3.0.4 from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt).
Installation instructions for compatibility with TensorFlow are provided on the
-[TensorFlow Installation page](https://www.tensorflow.org/install/install_linux#nvidia_requirements_to_run_tensorflow_with_gpu_support).
+[TensorFlow GPU support](https://www.tensorflow.org/install/gpu) guide.
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index b019c99882..7ad9bf22d3 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -678,7 +678,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
// Function to construct a funcdef from the segment and add it to the graph.
tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::Graph* graph, const tensorflow::GraphDef& segment,
- const string& name) {
+ const string& engine_name) {
tensorflow::Graph sgraph(graph->flib_def());
tensorflow::GraphConstructorOptions gcopts;
TF_RETURN_IF_ERROR(
@@ -761,9 +761,9 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
tensorflow::FunctionDefLibrary fdeflib;
auto native_segment = fdeflib.add_function();
TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
- sgraph, StrCat(name, "_native_segment"), native_segment));
+ sgraph, StrCat(engine_name, "_native_segment"), native_segment));
if (VLOG_IS_ON(7)) {
- VLOG(7) << name << " Function_Def ";
+ VLOG(7) << engine_name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
}
VLOG(1) << "Adding funcdef to graphlib";
@@ -780,12 +780,12 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
// If device is not set, use the first found GPU device for the conversion.
for (int tf_gpu_id_value = 0; tf_gpu_id_value < 100; ++tf_gpu_id_value) {
TfGpuId tf_gpu_id(tf_gpu_id_value);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (s.ok()) {
VLOG(1) << "Found TF GPU " << tf_gpu_id.value() << " at cuda device "
- << cuda_gpu_id.value();
- cuda_device_id = cuda_gpu_id.value();
+ << platform_gpu_id.value();
+ cuda_device_id = platform_gpu_id.value();
GPUOptions gpu_options;
// If the TF to Cuda gpu id mapping exist, the device and corresponding
// allocator must have been initialized already, so the
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index c98b07ad8b..0ce891782e 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -693,8 +693,15 @@ class Converter {
// TODO(jie): tf protobuf seems to be omitting the :0 suffix
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
+ // We need to check the name before setting it. For Identity op where the
+ // output is the input, if its input is one of the engine input, setting
+ // the name here will overwrite engine input bindings which will cause
+ // runtime error.
if (output.is_tensor()) {
- output.tensor()->setName(output_name.c_str());
+ const char* tensor_name = output.tensor()->getName();
+ if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+ output.tensor()->setName(output_name.c_str());
+ }
}
VLOG(2) << "Adding out tensor " << output_name << ": "
<< output.DebugString();
@@ -779,12 +786,11 @@ class Converter {
// skip control nodes
if (input_name[0] == '^') continue;
string name = input_name;
- auto first = name.find_first_of(':');
- // TODO(aaroey): why removing the colon but not the zero? A bug?
+ auto last = name.find_last_of(':');
// TODO(aaroey): use TensorId
- if (first != string::npos && first + 2 == name.size() &&
- name[first + 1] == '0') {
- name.erase(first);
+ if (last != string::npos && last + 2 == name.size() &&
+ name[last + 1] == '0') {
+ name.erase(last);
}
if (trt_tensors_.count(name)) {
@@ -2697,7 +2703,6 @@ tensorflow::Status ConvertGraphDefToEngine(
TrtUniquePtrType<nvinfer1::IBuilder> builder(
nvinfer1::createInferBuilder(*logger));
builder->setMaxBatchSize(max_batch_size);
- // TODO(aaroey): use the allocator to allocate the TRT workspace.
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
#if NV_TENSORRT_MAJOR > 3
builder->setGpuAllocator(allocator);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 2b42d81f47..88cf8d5980 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -565,21 +565,22 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
new TRTInt8Calibrator(device_buffers_, batch_size, name()));
const string label(name());
auto segment_graph = &segment_graph_;
- const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
- if (cuda_gpu_id < 0) {
+ const int platform_gpu_id =
+ ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+ if (platform_gpu_id < 0) {
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
return tensorflow::errors::InvalidArgument(
"Context->device doesn't contain device info!");
}
const int64 workspace_size_bytes = workspace_size_;
cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
- cuda_gpu_id, workspace_size_bytes]() {
- VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
+ platform_gpu_id, workspace_size_bytes]() {
+ VLOG(0) << "Starting calibration thread on device " << platform_gpu_id
<< ", Calibration Resource @ " << cres;
- auto err = cudaSetDevice(cuda_gpu_id);
+ auto err = cudaSetDevice(platform_gpu_id);
if (err != cudaSuccess) {
// TODO(aaroey): should return error here.
- LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+ LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
<< " in calibration thread";
}
// ConvertGraphDefToEngine() will try to build the engine. This thread
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2fe30..369e73b5a6 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
import six as _six
+# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@ from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_vers
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+if _six.PY2:
+ _to_bytes = lambda s: s
+ _to_string = lambda s: s
+else:
+ _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+ _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT8 = "INT8"
+
+ @staticmethod
+ def supported_precision_modes():
+ return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+ max_workspace_size_bytes=2 << 20,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batch_sizes=None):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ max_batch_size: max size for the input batch
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should be smaller than maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+
+ Returns:
+ A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+ Raises:
+ TypeError: if the provided precision mode is invalid.
+ ValueError: if len(cached_engine_batch_sizes) exceed maximum_cached_engines.
+ """
+ if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+ raise ValueError(("precision mode '{}' is not supported."
+ "It should be one of {}").format(
+ precision_mode,
+ TrtPrecisionMode.supported_precision_modes))
+
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batch_sizes:
+ if not isinstance(cached_engine_batch_sizes, list):
+ raise TypeError("cached_engine_batch_sizes should be a list.")
+ if len(cached_engine_batch_sizes) > maximum_cached_engines:
+ raise ValueError("cached_engine_batch_sizes should not contain more than "
+ "maximum_cached_engines items.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batch_sizes)
+ return rewriter_cfg
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
max_workspace_size_bytes=2 << 20,
- precision_mode="FP32",
+ precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=None):
+ cached_engine_batch_sizes=None,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
"""Python wrapper for the TRT transformation.
Args:
- input_graph_def: GraphDef object containing a model to be transformed.
- outputs: list of tensors or node names for the model outputs.
- max_batch_size: max size for the input batch
- max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
- precision_mode: one of 'FP32', 'FP16' and 'INT8'
+ input_graph_def: a GraphDef object containing a model to be transformed. If
+ set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ outputs: list of tensors or node names for the model outputs. Only used when
+ input_graph_def is not None.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- cached_engine_batches: batch sizes used to pre-create cached engines.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should be smaller than maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transforms. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ output_saved_model_dir: if not None, construct a SavedModel using the
+ returned GraphDef and save it to the specified directory. This option only
+ works when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None.
+ session_config: the ConfigProto used to create a Session. If not specified,
+ a default ConfigProto will be used.
Returns:
- New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+ A GraphDef transformed from input_graph_def (or the SavedModel graph def
+ loaded from input_saved_model_dir, if input_graph_def is not present), where
+ all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+ function is added for each of the subgraphs.
+
+ If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+ subgraph GraphDef, which will be converted to a TRT engine at execution time
+ and the TRT engine will be cached for future usage. A new TRT engine will be
+ created each time when none of the cached engines match the input shapes. If
+ it fails to execute the TRT engine or the number of cached engines reaches
+ maximum_cached_engines, the op will fall back to call the corresponding TF
+ function.
+
+ If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+ engine created from the corresponding subgraph. No more engines will be
+ created on the fly, and the op will fall back to call the corresponding TF
+ function when it fails to execute the engine.
Raises:
- ValueError: if the provided precision mode is invalid.
- RuntimeError: if the returned status message is malformed.
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
"""
- supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
- if precision_mode.upper() not in supported_precision_modes:
- raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(
- precision_mode, "{'FP32', 'FP16', 'INT8'}"))
- mode = supported_precision_modes[precision_mode.upper()]
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
@@ -101,61 +225,111 @@ def create_inference_graph(input_graph_def,
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
- def py2bytes(inp):
- return inp
+ if session_config is None:
+ session_config = config_pb2.ConfigProto()
+
+ if input_saved_model_tags is None:
+ input_saved_model_tags = [tag_constants.SERVING]
+ saved_model_loader = None
+ grappler_meta_graph_def = None
- def py3bytes(inp):
- return inp.encode("utf-8", errors="surrogateescape")
+ if input_graph_def is None:
+ # Read from SavedModel and freeze the graph if necessary.
+ if input_saved_model_dir is None:
+ raise ValueError("input_graph_def and input_saved_model_dir cannot be "
+ "both None")
+ with ops.Graph().as_default():
+ with session.Session(config=session_config) as sess:
+ saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+ input_meta_graph_def = saved_model_loader.load(sess,
+ input_saved_model_tags)
+ output_node_names = set()
- def py2string(inp):
- return inp
+ def _gather_names(tensor_info):
+ """Get the node names from a TensorInfo."""
+ return set(
+ [tensor_info[key].name.split(":")[0] for key in tensor_info])
- def py3string(inp):
- return inp.decode("utf-8")
+ # Get input and outputs from all SignatureDef.
+ for key in input_meta_graph_def.signature_def:
+ signature_def = input_meta_graph_def.signature_def[key]
+ output_node_names.update(_gather_names(signature_def.inputs))
+ output_node_names.update(_gather_names(signature_def.outputs))
- if _six.PY2:
- to_bytes = py2bytes
- to_string = py2string
+ # Freeze the variables in the SavedModel graph and copy the frozen
+ # graph over.
+ frozen_graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph.as_graph_def(add_shapes=True),
+ list(output_node_names))
+ grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+ # Copy the collections that are not variables.
+ for key in input_meta_graph_def.collection_def:
+ # TODO(laigd): currently we use the collection key to filter out
+ # collections that depend on variable ops, but this may miss some
+ # other user-defined collections. A better way would be to use
+ # CollectionDef::NodeList for the filtering.
+ if key not in [
+ "variables", "local_variables", "model_variables",
+ "trainable_variables", "train_op", "table_initializer"
+ ]:
+ grappler_meta_graph_def.collection_def[key].CopyFrom(
+ input_meta_graph_def.collection_def[key])
+
+ # Copy other information.
+ grappler_meta_graph_def.meta_info_def.CopyFrom(
+ input_meta_graph_def.meta_info_def)
+ for key in input_meta_graph_def.signature_def:
+ grappler_meta_graph_def.signature_def[key].CopyFrom(
+ input_meta_graph_def.signature_def[key])
+ # TODO(laigd): maybe add back AssetFileDef.
else:
- to_bytes = py3bytes
- to_string = py3string
-
- # Create MetaGraphDef
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(input_graph_def, name="")
- meta_graph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- if outputs:
- output_collection = meta_graph_pb2.CollectionDef()
- output_list = output_collection.node_list.value
- for i in outputs:
- if isinstance(i, ops.Tensor):
- output_list.append(to_bytes(i.name))
- else:
- output_list.append(to_bytes(i))
- meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+ if output_saved_model_dir is not None:
+ raise ValueError("output_saved_model_dir cannot be set when "
+ "input_graph_def is set")
+ # Create MetaGraphDef from input graph.
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(_to_bytes(i.name))
+ else:
+ output_list.append(_to_bytes(i))
+ # TODO(laigd): use another key as the outputs are really not train_op.
+ grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ output_collection)
# Create RewriterConfig.
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- optimizer = rewriter_cfg.custom_optimizers.add()
- optimizer.name = "TensorRTOptimizer"
- optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
- optimizer.parameter_map["max_batch_size"].i = max_batch_size
- optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
- optimizer.parameter_map[
- "max_workspace_size_bytes"].i = max_workspace_size_bytes
- optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
- optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
- if cached_engine_batches:
- if not isinstance(cached_engine_batches, list):
- raise TypeError("cached_engine_batches should be a list.")
- optimizer.parameter_map["cached_engine_batches"].list.i.extend(
- cached_engine_batches)
+ rewriter_cfg = tensorrt_rewriter_config(
+ max_batch_size, max_workspace_size_bytes, precision_mode,
+ minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+ cached_engine_batch_sizes)
+
+ # Run Grappler.
+ transformed_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
- return tf_optimizer.OptimizeGraph(
- rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+ # Optionally write the transformed graphdef as SavedModel.
+ if output_saved_model_dir is not None:
+ saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+ with ops.Graph().as_default():
+ importer.import_graph_def(transformed_graph_def, name="")
+ with session.Session(config=session_config) as sess:
+ saved_model_builder.add_meta_graph_and_variables(
+ sess,
+ input_saved_model_tags,
+ signature_def_map=grappler_meta_graph_def.signature_def)
+ # Ignore other meta graphs from the input SavedModel.
+ saved_model_builder.save()
+
+ return transformed_graph_def
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
Args:
calibration_graph_def: the calibration GraphDef object with calibration data
is_dynamic_op: whether to create dynamic static engines from calibration
+
Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises:
RuntimeError: if the returned status message is malformed.
"""
- def py2string(inp):
- return inp
-
- def py3string(inp):
- return inp.decode("utf-8")
-
- if _six.PY2:
- to_string = py2string
- else:
- to_string = py3string
is_calib_graph = False
for n in calibration_graph_def.node:
if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
return None
graph_str = calibration_graph_def.SerializeToString()
out = calib_convert(graph_str, is_dynamic_op)
- status = to_string(out[0])
+ status = _to_string(out[0])
output_graph_def_string = out[1]
del graph_str # Save some memory
if len(status) < 2:
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
new file mode 100644
index 0000000000..f3a1ef0d47
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.tools import saved_model_utils
+
+
+class TrtConvertTest(test_util.TensorFlowTestCase):
+ """Class to test Tensorflow-TensorRT integration python API."""
+
+ def testTensorrtRewriterConfig(self):
+ """Test case for trt_convert.tensorrt_rewriter_config()."""
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ max_batch_size=128,
+ max_workspace_size_bytes=1234,
+ precision_mode="INT8",
+ minimum_segment_size=10,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ cached_engine_batch_sizes=[1, 128])
+ trt_optimizer = None
+ for optimizer in rewriter_cfg.custom_optimizers:
+ if optimizer.name == "TensorRTOptimizer":
+ self.assertTrue(trt_optimizer is None)
+ trt_optimizer = optimizer
+ self.assertTrue(trt_optimizer is not None)
+ for key in [
+ "minimum_segment_size", "max_batch_size", "is_dynamic_op",
+ "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
+ "cached_engine_batches"
+ ]:
+ self.assertTrue(key in trt_optimizer.parameter_map)
+ self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
+ self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
+ self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
+ self.assertEqual(1234,
+ trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
+ self.assertEqual(
+ trt_convert._to_bytes("INT8"),
+ trt_optimizer.parameter_map["precision_mode"].s)
+ self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
+ self.assertEqual(
+ [1, 128],
+ trt_optimizer.parameter_map["cached_engine_batches"].list.i)
+
+ def _GetConfigProto(self):
+ """Get ConfigProto for session creation."""
+ config = config_pb2.ConfigProto(
+ gpu_options=config_pb2.GPUOptions(allow_growth=True))
+ return config
+
+ def _GetGraph(self):
+ """Get the graph for testing."""
+ g = ops.Graph()
+ with g.as_default():
+ with g.device("/GPU:0"):
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[None, 1, 1], name="input")
+ var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+ add = inp + var.value()
+ mul = inp * add
+ add = mul + add
+ out = array_ops.identity(add, name="output")
+ return g, var, inp, out
+
+ def _GetGraphDef(self):
+ """Get the graph def for testing."""
+ g, var, _, _ = self._GetGraph()
+ with self.session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ graph_def = graph_util.convert_variables_to_constants(
+ sess, g.as_graph_def(add_shapes=True), ["output"])
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "v1": "Const",
+ "v1/read": "Identity",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+ return graph_def
+
+ def _WriteInputSavedModel(self, input_saved_model_dir):
+ """Write the saved model as an input for testing."""
+ g, var, inp, out = self._GetGraph()
+ signature_def = signature_def_utils.build_signature_def(
+ inputs={"myinput": utils.build_tensor_info(inp)},
+ outputs={"myoutput": utils.build_tensor_info(out)},
+ method_name=signature_constants.PREDICT_METHOD_NAME)
+ saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
+ with self.session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ saved_model_builder.add_meta_graph_and_variables(
+ sess, [tag_constants.SERVING],
+ signature_def_map={"mypredict": signature_def})
+ saved_model_builder.save()
+
+ def _TestCreateInferenceGraph(self,
+ input_saved_model_dir=None,
+ output_saved_model_dir=None):
+ """General method to test trt_convert.create_inference_graph()."""
+ input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
+ output_graph_def = trt_convert.create_inference_graph(
+ input_graph_def, ["output"],
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+ graph_defs_to_verify = [output_graph_def]
+ if output_saved_model_dir is not None:
+ saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+ output_saved_model_dir, tag_constants.SERVING).graph_def
+ self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+ graph_defs_to_verify.append(saved_model_graph_def)
+
+ for graph_def in graph_defs_to_verify:
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "input": "Placeholder",
+ "my_trt_op_0": "TRTEngineOp",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_BasicConversion(self):
+ """Test case for trt_convert.create_inference_graph()."""
+ if not trt_convert.is_tensorrt_enabled():
+ return
+
+ # Use GraphDef as input.
+ self._TestCreateInferenceGraph()
+
+ # Use SavedModel as input.
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir1")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ self._TestCreateInferenceGraph(input_saved_model_dir,
+ output_saved_model_dir)
+
+ def _TestRun(self, sess, batch_size, expect_engine_is_run):
+ trt_convert.clear_test_values("")
+ result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+ self.assertAllEqual([[[4.0]]] * batch_size, result)
+ execute_engine_test_value = ("done" if expect_engine_is_run else "")
+ execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
+ self.assertEqual(execute_engine_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine"))
+ self.assertEqual(
+ execute_native_segment_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment"))
+
+ def testCreateInferenceGraph_MinimumSegmentSize(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ output_graph_def = trt_convert.create_inference_graph(
+ self._GetGraphDef(), ["output"],
+ minimum_segment_size=5,
+ is_dynamic_op=False)
+ node_name_to_op = {node.name: node.op for node in output_graph_def.node}
+ self.assertEqual({
+ "v1/read": "Const",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_DynamicOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3, since the number of cached engines has reached
+ # the max, it should fall back to TF function.
+ self._TestRun(sess, 3, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3, since the number of cached engines has reached
+ # the max, it should fall back to TF function.
+ self._TestRun(sess, 3, False)
+
+ def testCreateInferenceGraph_StaticOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ max_batch_size=1,
+ is_dynamic_op=False,
+ maximum_cached_engines=2, # This is noop, added just for testing.
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceed the max_batch_size, it should fall
+ # back to TF function.
+ self._TestRun(sess, 2, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceed the max_batch_size, it should fall
+ # back to TF function.
+ self._TestRun(sess, 2, False)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
index d8f97bfbbc..a9425864dd 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.cc
@@ -27,12 +27,16 @@ namespace tensorflow {
namespace tensorrt {
// std::align is not supported, so this method mimic its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space) {
- QCHECK_GT(alignment, 0) << "alignment must be greater than 0.";
+//
+// NOTE(aaroey): according to the TensorRT API,
+// nvinfer1::IGpuAllocator::allocate() uses uint64_t type for size and alignment
+// parameters, so here we use the same type to make it compatible.
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space) {
+ QCHECK_GT(alignment, 0ul) << "alignment must be greater than 0.";
QCHECK_EQ(0, alignment & (alignment - 1)) << "Alignment must be power of 2.";
- QCHECK_GT(size, 0) << "size must be greater than 0.";
+ QCHECK_GT(size, 0ul) << "size must be greater than 0.";
QCHECK(ptr) << "ptr must not be nullptr.";
- QCHECK_GT(space, 0) << "space must be greater than 0.";
+ QCHECK_GT(space, 0ul) << "space must be greater than 0.";
const uintptr_t ptr_val = reinterpret_cast<uintptr_t>(ptr);
QCHECK_GE(ptr_val + space, ptr_val) << "Provided space overflows.";
@@ -67,12 +71,16 @@ void TRTCudaAllocator::free(void* memory) { cudaFree(memory); }
void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment,
uint32_t flags) {
+ if (size == 0) return nullptr;
// WAR for allocator alignment requirement. Certain cuda API calls require GPU
// memory with alignemtn to cudaDeviceProp::textureAlignment.
// See issue #20856
alignment = 512;
assert((alignment & (alignment - 1)) == 0); // zero or a power of 2.
- size_t total_size = size + alignment;
+ uint64_t total_size = size + alignment;
+ // TODO(aaroey): AllocateRaw takes size_t size as input, so it'll produce
+ // unexpected result when TRT tries to allocate more bytes than size_t can
+ // carry. Fix this.
void* mem = allocator_->AllocateRaw(alignment, total_size);
if (!mem) return nullptr;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator.h b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
index 6f94492083..dc9862b16c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator.h
@@ -29,7 +29,7 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
// std::align is not supported, so this function mimic its behavior.
-void* Align(size_t alignment, size_t size, void*& ptr, size_t& space);
+void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space);
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
index f515ed03f2..ad6b1d7d4c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_allocator_test.cc
@@ -20,11 +20,11 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-bool RunTest(const size_t alignment, const size_t size,
- const intptr_t orig_ptr_val, const size_t orig_space) {
+bool RunTest(const uint64_t alignment, const uint64_t size,
+ const intptr_t orig_ptr_val, const uint64_t orig_space) {
void* const orig_ptr = reinterpret_cast<void*>(orig_ptr_val);
void* ptr = orig_ptr;
- size_t space = orig_space;
+ uint64_t space = orig_space;
void* result = Align(alignment, size, ptr, space);
if (result == nullptr) {
EXPECT_EQ(orig_ptr, ptr);
@@ -43,24 +43,25 @@ bool RunTest(const size_t alignment, const size_t size,
}
TEST(TRTAllocatorTest, Align) {
- for (const size_t space :
- {1, 2, 3, 4, 7, 8, 9, 10, 16, 32, 511, 512, 513, 700, 12345}) {
- for (size_t alignment = 1; alignment <= space * 4; alignment *= 2) {
- for (const intptr_t ptr_val :
+ for (const uint64_t space :
+ {1ul, 2ul, 3ul, 4ul, 7ul, 8ul, 9ul, 10ul, 16ul, 32ul, 511ul, 512ul,
+ 513ul, 700ul, 12345ul, 1ul << 32}) {
+ for (uint64_t alignment = 1; alignment <= space * 4; alignment *= 2) {
+ for (const uintptr_t ptr_val :
{1ul, alignment == 1 ? 1ul : alignment - 1, alignment, alignment + 1,
alignment + (alignment / 2)}) {
if (ptr_val % alignment == 0) {
- for (const size_t size :
+ for (const uint64_t size :
{1ul, space == 1 ? 1ul : space - 1, space, space + 1}) {
EXPECT_EQ(space >= size, RunTest(alignment, size, ptr_val, space));
}
} else {
EXPECT_FALSE(RunTest(alignment, space, ptr_val, space));
- const size_t diff = alignment - ptr_val % alignment;
+ const uint64_t diff = alignment - ptr_val % alignment;
if (space > diff) {
EXPECT_TRUE(
RunTest(alignment, space - diff, ptr_val + diff, space - diff));
- for (const size_t size :
+ for (const uint64_t size :
{1ul, space - diff > 1 ? space - diff - 1 : 1ul, space - diff,
space - diff + 1, space - 1}) {
EXPECT_EQ(space - diff >= size,
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index e9ac833d55..7e9ffb05ab 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -183,6 +183,12 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
"my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
}
+ def ShouldRunTest(self, run_params):
+ """Whether to run the test."""
+ # Disable the test in fp16 mode since multiple matmul and add ops together
+ # can cause overflow.
+ return run_params.precision_mode != "FP16"
+
class PartiallyConvertedTestB(PartiallyConvertedTestA):
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 62f4e525f7..d2f65344da 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -144,14 +144,6 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
# mode, which is a bug. Re-enable this when trt library is fixed.
return not trt_test.IsQuantizationMode(run_params.precision_mode)
- def ExpectedAbsoluteTolerance(self, run_params):
- """The absolute tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
- def ExpectedRelativeTolerance(self, run_params):
- """The relative tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8bdb0..d26f260086 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph,
outputs=["output"],
@@ -216,7 +216,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21cf37..4f935a7665 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@ RunParams = namedtuple(
ConversionParams = namedtuple("ConversionParams", [
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
- "cached_engine_batches"
+ "cached_engine_batch_sizes"
])
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -135,11 +134,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
]),
max_workspace_size_bytes=1 << 25,
- precision_mode=self._ToBytes(run_params.precision_mode),
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
- cached_engine_batches=None)
+ cached_engine_batch_sizes=None)
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
@@ -180,11 +179,11 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def ExpectedAbsoluteTolerance(self, run_params):
"""The absolute tolerance to compare floating point results."""
- return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
def ExpectedRelativeTolerance(self, run_params):
"""The relative tolerance to compare floating point results."""
- return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-02
def _GetParamsCached(self):
if self._trt_test_params is None:
@@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- custom_op = rewriter_cfg.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
trt_params = self.GetConversionParams(run_params)
- custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
- custom_op.parameter_map["max_workspace_size_bytes"].i = (
- trt_params.max_workspace_size_bytes)
- custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
- custom_op.parameter_map["minimum_segment_size"].i = (
- trt_params.minimum_segment_size)
- custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
- custom_op.parameter_map["maximum_cached_engines"].i = (
- trt_params.maximum_cached_engines)
- if trt_params.cached_engine_batches:
- custom_op.parameter_map["cached_engine_batches"].list.i.extend(
- trt_params.cached_engine_batches)
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+ trt_params.precision_mode, trt_params.minimum_segment_size,
+ trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+ trt_params.cached_engine_batch_sizes)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
@@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=trt_params.minimum_segment_size,
is_dynamic_op=trt_params.is_dynamic_op,
maximum_cached_engines=trt_params.maximum_cached_engines,
- cached_engine_batches=trt_params.cached_engine_batches)
+ cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL:
@@ -426,6 +414,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
if not self.ShouldRunTest(run_params):
return
assert run_params.precision_mode in PRECISION_MODES
+ np.random.seed(12345)
params = self._GetParamsCached()
input_gdef = params.gdef
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index 84e36146d5..832d34d60d 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -63,7 +63,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"brown"),
(b"jumps", b"fox"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -94,7 +94,7 @@ class SkipGramOpsTest(test.TestCase):
(b"jumps", b"fox"),
(b"jumps", b"jumps"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -105,7 +105,7 @@ class SkipGramOpsTest(test.TestCase):
# If emit_self_as_target is False (default), output will be empty.
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, tokens.eval().size)
self.assertEqual(0, labels.eval().size)
@@ -117,7 +117,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"quick"),
(b"brown", b"brown"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -134,7 +134,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"the"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -150,7 +150,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -165,7 +165,7 @@ class SkipGramOpsTest(test.TestCase):
(b"quick", b"brown"),
(b"brown", b"quick"),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -196,7 +196,7 @@ class SkipGramOpsTest(test.TestCase):
(b"over", b"fox"),
(b"over", b"jumps"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
self.assertAllEqual(expected_labels, labels_eval)
@@ -222,7 +222,7 @@ class SkipGramOpsTest(test.TestCase):
tokens_2, labels_2 = text.skip_gram_sample(
input_tensor, min_skips=1, max_skips=5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run(
[tokens_1, labels_1, tokens_2, labels_2])
@@ -244,7 +244,7 @@ class SkipGramOpsTest(test.TestCase):
(b"brown", b"fox"),
(b"fox", b"brown"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
@@ -269,7 +269,7 @@ class SkipGramOpsTest(test.TestCase):
(2, 3),
(3, 2),
])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -286,7 +286,7 @@ class SkipGramOpsTest(test.TestCase):
for min_skips, max_skips in invalid_skips:
tokens, labels = text.skip_gram_sample(
input_tensor, min_skips=min_skips, max_skips=max_skips)
- with self.test_session() as sess, self.assertRaises(
+ with self.cached_session() as sess, self.assertRaises(
errors.InvalidArgumentError):
sess.run([tokens, labels])
@@ -338,7 +338,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
# No vocab_freq_table specified - output should be the same as input.
@@ -395,7 +395,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_table = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), -1)
- with self.test_session():
+ with self.cached_session():
vocab_freq_table.init.run()
output = skip_gram_ops._filter_input(
input_tensor=input_tensor,
@@ -464,7 +464,7 @@ class SkipGramOpsTest(test.TestCase):
(b"life", b"and"),
(b"and", b"life"),
])
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(expected_tokens, tokens.eval())
self.assertAllEqual(expected_labels, labels.eval())
@@ -510,7 +510,7 @@ class SkipGramOpsTest(test.TestCase):
(b"to", b"life"),
(b"life", b"to"),
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lookup_ops.tables_initializer().run()
tokens_eval, labels_eval = sess.run([tokens, labels])
self.assertAllEqual(expected_tokens, tokens_eval)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index d808945334..9bbe87e301 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -191,6 +191,43 @@ class ARModel(model.TimeSeriesModel):
Note that this class can also be used to regress against time only by setting
the input_window_size to zero.
+
+ Each periodicity in the `periodicities` arg is divided by the
+ `num_time_buckets` into time buckets that are represented as features added
+ to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_time_buckets` is how often
+ the data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the time bucket it falls under. If it doesn't fall
+ under a feature's associated time bucket, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd time bucket for periodicity 9 (2nd period is 9-18 and 3rd
+ time bucket is 15-18)
+ - It's in the 2nd time bucket for periodicity 12 (2nd period is 12-24 and
+ 2nd time bucket is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timebucket#), feature value
+ P9_T1, 0 # not in first time bucket
+ P9_T2, 0 # not in second time bucket
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket
+ P12_T1, 0 # not in first time bucket
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket
+ P12_T3, 0 # not in third time bucket
"""
SQUARED_LOSS = "squared_loss"
NORMAL_LIKELIHOOD_LOSS = "normal_likelihood_loss"
@@ -208,7 +245,9 @@ class ARModel(model.TimeSeriesModel):
Args:
periodicities: periodicities of the input data, in the same units as the
- time feature. Note this can be a single value or a list of values for
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
multiple periodicities.
input_window_size: Number of past time steps of data to look at when doing
the regression.
@@ -218,21 +257,18 @@ class ARModel(model.TimeSeriesModel):
prediction_model_factory: A callable taking arguments `num_features`,
`input_window_size`, and `output_window_size` and returning a
`tf.keras.Model`. The `Model`'s `call()` takes two arguments: an input
- window and an output window, and returns a dictionary of
- predictions. See `FlatPredictionModel` for an example. Example usage:
+ window and an output window, and returns a dictionary of predictions.
+ See `FlatPredictionModel` for an example. Example usage:
- ```python
- model = ar_model.ARModel(
- periodicities=2, num_features=3,
- prediction_model_factory=functools.partial(
- FlatPredictionModel,
- hidden_layer_sizes=[10, 10]))
- ```
+ ```python model = ar_model.ARModel( periodicities=2, num_features=3,
+ prediction_model_factory=functools.partial( FlatPredictionModel,
+ hidden_layer_sizes=[10, 10])) ```
The default model computes predictions as a linear function of flattened
input and output windows.
num_time_buckets: Number of buckets into which to divide (time %
- periodicity) for generating time based features.
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
loss: Loss function to use for training. Currently supported values are
SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
@@ -240,10 +276,9 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
exogenous_feature_columns: A list of `tf.feature_column`s (for example
- `tf.feature_column.embedding_column`) corresponding to exogenous
- features which provide extra information to the model but are not part
- of the series to be predicted. Passed to
- `tf.feature_column.input_layer`.
+ `tf.feature_column.embedding_column`) corresponding to
+ features which provide extra information to the model but are not part
+ of the series to be predicted.
"""
self._model_factory = prediction_model_factory
self.input_window_size = input_window_size
@@ -264,10 +299,10 @@ class ARModel(model.TimeSeriesModel):
elif (not isinstance(periodicities, list) and
not isinstance(periodicities, tuple)):
periodicities = [periodicities]
- self._periods = [int(p) for p in periodicities]
- for p in self._periods:
+ self._periodicities = [int(p) for p in periodicities]
+ for p in self._periodicities:
assert p > 0
- assert len(self._periods) or self.input_window_size
+ assert len(self._periodicities) or self.input_window_size
assert output_window_size > 0
def initialize_graph(self, input_statistics=None):
@@ -364,9 +399,9 @@ class ARModel(model.TimeSeriesModel):
input_feature_size = 0
output_window_features = []
output_feature_size = 0
- if self._periods:
+ if self._periodicities:
_, time_features = self._compute_time_features(times)
- num_time_features = self._buckets * len(self._periods)
+ num_time_features = self._buckets * len(self._periodicities)
time_features = array_ops.reshape(
time_features,
[batch_size,
@@ -849,12 +884,12 @@ class ARModel(model.TimeSeriesModel):
def _compute_time_features(self, time):
"""Compute some features on the time value."""
batch_size = array_ops.shape(time)[0]
- num_periods = len(self._periods)
+ num_periods = len(self._periodicities)
# Reshape to 3D.
periods = constant_op.constant(
- self._periods, shape=[1, 1, num_periods, 1], dtype=time.dtype)
+ self._periodicities, shape=[1, 1, num_periods, 1], dtype=time.dtype)
time = array_ops.reshape(time, [batch_size, -1, 1, 1])
- window_offset = time / self._periods
+ window_offset = time / self._periodicities
# Cast to appropriate type and scale to [0, 1) range
mod = (math_ops.cast(time % periods, self.dtype) * self._buckets /
math_ops.cast(periods, self.dtype))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 0ddc4b4144..af68aa03cf 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -30,6 +30,7 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models import s
from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filtering_postprocessor import StateInterpolatingAnomalyDetector
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.estimator.canned import optimizers
from tensorflow.python.estimator.export import export_lib
from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
@@ -386,6 +387,162 @@ class ARRegressor(TimeSeriesRegressor):
config=config)
+# TODO(b/113684821): Add detailed documentation on what the input_fn should do.
+# Add an example of making and returning a Dataset object. Determine if
+# endogenous features can be passed in as FeatureColumns. Move ARModel's loss
+# functions into a more general location.
+class LSTMAutoRegressor(TimeSeriesRegressor):
+ """An Estimator for an LSTM autoregressive model.
+
+ LSTMAutoRegressor is a window-based model, inputting fixed windows of length
+ `input_window_size` and outputting fixed windows of length
+ `output_window_size`. These two parameters must add up to the window_size
+ of data returned by the `input_fn`.
+
+ Each periodicity in the `periodicities` arg is divided by the `num_timesteps`
+ into timesteps that are represented as time features added to the model.
+
+ A good heuristic for picking an appropriate periodicity for a given data set
+ would be the length of cycles in the data. For example, energy usage in a
+ home is typically cyclic each day. If the time feature in a home energy
+ usage dataset is in the unit of hours, then 24 would be an appropriate
+ periodicity. Similarly, a good heuristic for `num_timesteps` is how often the
+ data is expected to change within the cycle. For the aforementioned home
+ energy usage dataset and periodicity of 24, then 48 would be a reasonable
+ value if usage is expected to change every half hour.
+
+ Each feature's value for a given example with time t is the difference
+ between t and the start of the timestep it falls under. If it doesn't fall
+ under a feature's associated timestep, then that feature's value is zero.
+
+ For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6
+ features would be added to the model, 3 for periodicity 9 and 3 for
+ periodicity 12.
+
+ For an example data point where t = 17:
+ - It's in the 3rd timestep for periodicity 9 (2nd period is 9-18 and 3rd
+ timestep is 15-18)
+ - It's in the 2nd timestep for periodicity 12 (2nd period is 12-24 and
+ 2nd timestep is between 16-20).
+
+ Therefore the 6 added features for this row with t = 17 would be:
+
+ # Feature name (periodicity#_timestep#), feature value
+ P9_T1, 0 # not in first timestep
+ P9_T2, 0 # not in second timestep
+ P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep
+ P12_T1, 0 # not in first timestep
+ P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep
+ P12_T3, 0 # not in third timestep
+
+ Example Code:
+
+ ```python
+ extra_feature_columns = (
+ feature_column.numeric_column("exogenous_variable"),
+ )
+
+ estimator = LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=5,
+ model_dir="/path/to/model/dir",
+ num_features=1,
+ extra_feature_columns=extra_feature_columns,
+ num_timesteps=50,
+ num_units=10,
+ optimizer=tf.train.ProximalAdagradOptimizer(...))
+
+ # Input builders
+ def input_fn_train():
+ return {
+ "times": tf.range(15)[None, :],
+ "values": tf.random_normal(shape=[1, 15, 1])
+ }
+ estimator.train(input_fn=input_fn_train, steps=100)
+
+ def input_fn_eval():
+ pass
+ metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)
+
+ def input_fn_predict():
+ pass
+ predictions = estimator.predict(input_fn=input_fn_predict)
+ ```
+ """
+
+ def __init__(self,
+ periodicities,
+ input_window_size,
+ output_window_size,
+ model_dir=None,
+ num_features=1,
+ extra_feature_columns=None,
+ num_timesteps=10,
+ loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,
+ num_units=128,
+ optimizer="Adam",
+ config=None):
+ """Initialize the Estimator.
+
+ Args:
+ periodicities: periodicities of the input data, in the same units as the
+ time feature (for example 24 if feeding hourly data with a daily
+ periodicity, or 60 * 24 if feeding minute-level data with daily
+ periodicity). Note this can be a single value or a list of values for
+ multiple periodicities.
+ input_window_size: Number of past time steps of data to look at when doing
+ the regression.
+ output_window_size: Number of future time steps to predict. Note that
+ setting this value to > 1 empirically seems to give a better fit.
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator
+ to continue training a previously saved model.
+ num_features: The dimensionality of the time series (default value is
+ one for univariate, more than one for multivariate).
+ extra_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to features which
+ provide extra information to the model but are not part of the series to
+ be predicted.
+ num_timesteps: Number of buckets into which to divide (time %
+ periodicity). This value multiplied by the number of periodicities is
+ the number of time features added to the model.
+ loss: Loss function to use for training. Currently supported values are
+ SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for
+ NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For
+ SQUARED_LOSS, the evaluation loss is reported based on un-scaled
+ observations and predictions, while the training loss is computed on
+ normalized data.
+ num_units: The size of the hidden state in the encoder and decoder LSTM
+ cells.
+ optimizer: string, `tf.train.Optimizer` object, or callable that defines
+ the optimizer algorithm to use for training. Defaults to the Adam
+ optimizer with a learning rate of 0.01.
+ config: Optional `estimator.RunConfig` object to configure the runtime
+ settings.
+ """
+ optimizer = optimizers.get_optimizer_instance(
+ optimizer, learning_rate=0.01)
+ model = ar_model.ARModel(
+ periodicities=periodicities,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ num_features=num_features,
+ exogenous_feature_columns=extra_feature_columns,
+ num_time_buckets=num_timesteps,
+ loss=loss,
+ prediction_model_factory=functools.partial(
+ ar_model.LSTMPredictionModel, num_units=num_units))
+ state_manager = state_management.FilteringOnlyStateManager()
+ super(LSTMAutoRegressor, self).__init__(
+ model=model,
+ state_manager=state_manager,
+ optimizer=optimizer,
+ model_dir=model_dir,
+ config=config,
+ head_type=ts_head_lib.OneShotPredictionHead)
+
+
class StateSpaceRegressor(TimeSeriesRegressor):
"""An Estimator for general state space models."""
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 461fe22210..6ec7184c68 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -216,6 +216,50 @@ class TimeSeriesRegressorTest(test.TestCase):
exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
+ def test_structural_ensemble_numpy_input(self):
+ numpy_data = {"times": numpy.arange(50),
+ "values": numpy.random.normal(size=[50])}
+ estimators.StructuralEnsembleRegressor(
+ num_features=1, periodicities=[], model_dir=self.get_temp_dir(),
+ config=_SeedRunConfig()).train(
+ input_pipeline.WholeDatasetInputFn(
+ input_pipeline.NumpyReader(numpy_data)),
+ steps=1)
+
+ def test_ar_lstm_regressor(self):
+ dtype = dtypes.float32
+ model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ estimator = estimators.LSTMAutoRegressor(
+ periodicities=10,
+ input_window_size=10,
+ output_window_size=6,
+ model_dir=model_dir,
+ num_features=1,
+ extra_feature_columns=exogenous_feature_columns,
+ num_units=10,
+ config=_SeedRunConfig())
+ times = numpy.arange(20, dtype=numpy.int64)
+ values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ features = {
+ feature_keys.TrainEvalFeatures.TIMES: times,
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
+ }
+ train_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
+ batch_size=16, window_size=16)
+ eval_input_fn = input_pipeline.RandomWindowInputFn(
+ input_pipeline.NumpyReader(features), shuffle_seed=3, num_threads=1,
+ batch_size=16, window_size=16)
+ estimator.train(input_fn=train_input_fn, steps=1)
+ evaluation = estimator.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ self.assertAllEqual(evaluation["loss"], evaluation["average_loss"])
+ self.assertAllEqual([], evaluation["loss"].shape)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
index 9b593fecbb..03da2b82e5 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py
@@ -896,8 +896,8 @@ class InputStatisticsFromMiniBatch(object):
statistics.total_observation_count,
math_ops.cast(
gen_math_ops.round(
- math_ops.cast(auxiliary_variables.max_time_seen -
- statistics.start_time + 1, self._dtype) /
+ math_ops.cast(max_time_seen_assign -
+ start_time_update + 1, self._dtype) /
inter_observation_duration_estimate), dtypes.int64))
per_chunk_stat_updates = control_flow_ops.group(
overall_feature_mean_update, overall_feature_var_update,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
index 53d7340e85..a77c507d9b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/filtering_postprocessor_test.py
@@ -61,7 +61,7 @@ class FilteringStepPostprocessorTest(test.TestCase):
expected_state = [[[80.], [20.]],
[1., 6.],
[-1, -2]]
- with self.test_session():
+ with self.cached_session():
for interpolated, expected in zip(interpolated_state, expected_state):
self.assertAllClose(expected, interpolated.eval())
self.assertGreater(0., updated_outputs["anomaly_score"][0].eval())
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
index 57f29f3c7f..f636126a33 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/kalman_filter_test.py
@@ -98,7 +98,7 @@ class MultivariateTests(test.TestCase):
observation_model=observation_model,
predicted_observations=(observed_mean, observed_var),
observation_noise=observation_noise_covariance)
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state = numpy.array([[1., 1., 1., 1.]])
evaled_state_var = numpy.eye(4)[None]
for i in range(500):
@@ -136,7 +136,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_observed_from_state(self):
"""Compare observation mean and noise to hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[2., 1.]])
state_var = constant_op.constant([[[4., 0.], [0., 3.]]])
observed_mean, observed_var = self.kalman_filter.observed_from_state(
@@ -171,7 +171,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
observation_model=observation_model,
predicted_observations=predicted_observations,
observation_noise=observation_noise))
- with self.test_session() as session:
+ with self.cached_session() as session:
evaled_state, evaled_state_var = session.run([state, state_var])
for _ in range(300):
evaled_state, evaled_state_var = session.run(
@@ -231,7 +231,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_mean(self):
"""Compare state mean transitions with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state = self.kalman_filter.predict_state_mean(
state, self.transition_fn([1]))
@@ -245,7 +245,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
def test_predict_state_var(self):
"""Compare a variance transition with simple hand-computed values."""
- with self.test_session():
+ with self.cached_session():
state_var = constant_op.constant([[[1., 0.], [0., 2.]]])
state_var = self.kalman_filter.predict_state_var(
state_var, self.transition_fn([1]), self.power_sum_fn([1]))
@@ -259,7 +259,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.]])
state_var = constant_op.constant([[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([[
@@ -289,7 +289,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.assertGreater(first_log_prob.eval()[0], numpy.log(0.99))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
original_state = constant_op.constant([[4., 2.]])
n = 5
iterative_state = original_state
@@ -304,7 +304,7 @@ class KalmanFilterNonBatchTest(test.TestCase):
self.transition_fn([1]))
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
original_var = constant_op.constant([[[2., 3.], [4., 5.]]])
n = 5
iterative_var = original_var
@@ -330,7 +330,7 @@ class KalmanFilterBatchTest(test.TestCase):
Tests that correct values have high probability and incorrect values
have low probability when there is low uncertainty.
"""
- with self.test_session():
+ with self.cached_session():
state = constant_op.constant([[4., 2.], [5., 3.], [6., 4.]])
state_var = constant_op.constant(3 * [[[0.0001, 0.], [0., 0.0001]]])
observation = constant_op.constant([
@@ -378,7 +378,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertLess(third_log_prob.sum(), numpy.log(0.01))
def test_predict_n_ahead_mean(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, _ = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
@@ -396,7 +396,7 @@ class KalmanFilterBatchTest(test.TestCase):
self.assertAllClose(state2.eval()[2], batch_eval[2])
def test_predict_n_ahead_var(self):
- with self.test_session():
+ with self.cached_session():
kf = kalman_filter.KalmanFilter()
transition_fn, power_sum_fn = _powers_and_sums_from_transition_matrix(
state_transition=STATE_TRANSITION,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
index 951c6546d5..d04c721007 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py
@@ -909,7 +909,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
elif unbroadcasted_shape.ndims == 2:
# Unbroadcasted shape [num features x state dimension]
broadcasted_model = array_ops.tile(
- array_ops.expand_dims(unbroadcasted_model, dim=0),
+ array_ops.expand_dims(unbroadcasted_model, axis=0),
[array_ops.shape(times)[0], 1, 1])
elif unbroadcasted_shape.ndims == 3:
broadcasted_model = unbroadcasted_model
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
index c2eaa78493..80126ac786 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py
@@ -96,7 +96,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -114,7 +114,7 @@ class ConstructionTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -144,7 +144,7 @@ class GapTests(test.TestCase):
state=math_utils.replicate_state(
start_state=random_model.get_start_state(),
batch_size=array_ops.shape(times)[0]))
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -250,7 +250,7 @@ class StateSpaceEquivalenceTests(test.TestCase):
self.assertAllClose(combined_value, split_predict[prediction_key])
def _equivalent_to_single_model_test_template(self, model_generator):
- with self.test_session() as session:
+ with self.cached_session() as session:
random_model = RandomStateSpaceModel(
state_dimension=5,
state_noise_dimension=4,
@@ -374,7 +374,7 @@ class PredictionTests(test.TestCase):
math_utils.replicate_state(
start_state=random_model.get_start_state(), batch_size=1)
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = prediction_dict["mean"].eval()
predicted_covariance = prediction_dict["covariance"].eval()
@@ -404,7 +404,7 @@ class PredictionTests(test.TestCase):
feature_keys.PredictionFeatures.TIMES: [[5, 7, 8]],
feature_keys.PredictionFeatures.STATE_TUPLE: model_outputs.end_state
})
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
predicted_mean = predictions["mean"].eval()
predicted_covariance = predictions["covariance"].eval()
@@ -428,7 +428,7 @@ class ExogenousTests(test.TestCase):
state=[
array_ops.ones(shape=[1, 5]), original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -454,7 +454,7 @@ class ExogenousTests(test.TestCase):
-array_ops.ones(shape=[1, 5], dtype=dtype),
original_covariance[None], [0]
])
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
evaled_new_covariance, evaled_original_covariance = session.run(
[new_covariance[0], original_covariance])
@@ -519,7 +519,7 @@ class PosteriorTests(test.TestCase):
model=stub_model, data=data, true_parameters=true_params)
def test_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(data))
@@ -559,7 +559,7 @@ class PosteriorTests(test.TestCase):
posterior_times)
def test_chained_exact_posterior_recovery_no_transition_noise(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
stub_model, data, true_params = self._get_single_model()
chunk_size = 10
input_fn = test_utils.AllWindowInputFn(
@@ -748,7 +748,7 @@ class MultivariateTests(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
index 84885d5c9a..e8875f4eb9 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma_test.py
@@ -46,7 +46,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -65,7 +65,7 @@ class MakeModelTest(test.TestCase):
},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
@@ -85,7 +85,7 @@ class MakeModelTest(test.TestCase):
TrainEvalFeatures.VALUES: constant_op.constant([[[1.], [2.]]])},
mode=estimator_lib.ModeKeys.TRAIN)
initializer = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([initializer])
outputs.loss.eval()
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 298ffc1ded..4e0b61227e 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -80,7 +80,7 @@ tf_gen_op_libs(
"tpu_embedding_ops",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
],
@@ -99,7 +99,7 @@ tf_custom_op_library(
"ops/tpu_embedding_ops.cc",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_config_proto_cc",
+ "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/core:lib_proto_parsing",
],
)
@@ -351,7 +351,7 @@ tf_py_test(
tf_py_test(
name = "topology_test",
- size = "small",
+ size = "medium",
srcs = ["python/tpu/topology_test.py"],
additional_deps = [
":tpu",
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index ea8e0e00ed..87e3a5946c 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -125,4 +125,24 @@ output: The sum of all the distributed inputs.
T: The type of elements to be summed.
)doc");
+REGISTER_OP("CollectivePermute")
+ .Input("input: T")
+ .Input("source_target_pairs: int32")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+An Op to permute tensors across replicated TPU instances. Each instance
+supplies its own input.
+
+For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
+`[D, A, B, C]`.
+
+input: The local input to be permuted. Currently only supports float and
+ bfloat16.
+source_target_pairs: A tensor with shape [num_pairs, 2].
+output: The permuted input.
+T: The type of elements to be exchanged.
+)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc
index 15a2bb17a9..285e11d92d 100644
--- a/tensorflow/contrib/tpu/ops/replication_ops.cc
+++ b/tensorflow/contrib/tpu/ops/replication_ops.cc
@@ -24,9 +24,11 @@ using shape_inference::ShapeHandle;
REGISTER_OP("TPUReplicateMetadata")
.Attr("num_replicas: int >= 0")
+ .Attr("num_cores_per_replica: int = 1")
.Attr("topology: string = \"\"")
.Attr("use_tpu: bool = true")
.Attr("device_assignment: list(int) = []")
+ // Deprecated. Use num_cores_per_replica instead.
.Attr("computation_shape: list(int) = []")
.Attr("host_compute_core: list(string) = []")
.SetShapeFn(shape_inference::UnknownShape);
@@ -93,11 +95,11 @@ REGISTER_OP("TPUCompilationResult")
REGISTER_OP("TPUReplicate")
.Attr("computation: func")
.Attr("num_replicas: int >= 1")
+ .Attr("num_cores_per_replica: int = 1")
.Attr("topology: string = \"\"")
.Attr("use_tpu: bool = true")
.Attr("device_assignment: list(int) = []")
.Attr("host_compute_core: list(string) = []")
- .Attr("computation_shape: list(int) = []")
.Attr("Tinputs: list(type) >= 0")
.Attr("Tbroadcast_inputs: list(type) >= 0")
.Attr("NumVariables: int >= 0")
@@ -114,16 +116,15 @@ Runs replicated computations on a distributed TPU system.
computation: a function containing the computation to run.
num_replicas: the number of replicas of the computation to run.
+num_cores_per_replica: the number of logical cores in each replica.
topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
topology.
use_tpu: a bool indicating if this computation will run on TPU or CPU/GPU.
Currently, only supports a default placement (computation is placed on GPU
if one is available, and on CPU if not).
-computation_shape: a [mesh_dimension] array describing the shape of each
- computation replica in numbers of cores in the TPU mesh.
device_assignment: a flattened array with shape
- [replica] + computation_shape + [mesh_dimension] that maps the coordinates of
- logical cores in each replica of a computation to physical coordinates in
+ [replica, num_cores_per_replica, mesh_dimension] that maps the coordinates
+ of logical cores in each replica of a computation to physical coordinates in
the TPU topology.
Tinputs: the types of the arguments to 'computation'.
inputs: the inputs to 'computation', flattened, in replica-major order.
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 72d37f774c..18b98939b8 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_config.pb.h"
+#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -88,12 +88,12 @@ Status GradientDescentShapes(shape_inference::InferenceContext *c) {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
return Status::OK();
@@ -160,12 +160,12 @@ Status AdagradShapes(shape_inference::InferenceContext *c) {
int table_id;
TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
- int64 num_tables = config.table_config_size();
+ int64 num_tables = config.table_descriptor_size();
if (table_id >= num_tables) {
return errors::InvalidArgument("Table id >= num_tables");
}
- int64 width = config.table_config(table_id).width();
- int64 num_rows = config.table_config(table_id).num_rows();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_rows = config.table_descriptor(table_id).vocabulary_size();
TF_RETURN_IF_ERROR(c->set_output("parameters", {c->Matrix(num_rows, width)}));
TF_RETURN_IF_ERROR(
@@ -244,11 +244,11 @@ Status ActivationShapes(shape_inference::InferenceContext *c) {
if (!config.ParseFromString(config_string)) {
return errors::InvalidArgument("Malformed tpu_embedding_config.");
}
- int64 batch_size = config.batch_size();
- int64 num_tables = config.table_config_size();
+ int64 batch_size = config.batch_size_per_tensor_core();
+ int64 num_tables = config.table_descriptor_size();
for (int table_id = 0; table_id < num_tables; ++table_id) {
- int64 width = config.table_config(table_id).width();
- int64 num_features = config.table_config(table_id).num_features();
+ int64 width = config.table_descriptor(table_id).dimension();
+ int64 num_features = config.table_descriptor(table_id).vocabulary_size();
c->set_output(table_id, c->Matrix(batch_size * num_features, width));
}
return Status::OK();
diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
index 98cc31f18d..b4b06a40a2 100644
--- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc
@@ -142,9 +142,8 @@ Status WriteTensorboardTPUProfile(const string& logdir, const string& run,
TF_RETURN_IF_ERROR(DumpTraceToLogDirectory(profile_run_dir, host_prefix,
response.encoded_trace(), os));
}
- if (response.has_op_profile() &&
- (response.op_profile().has_by_program_structure() ||
- response.op_profile().has_by_category())) {
+ if (response.has_op_profile() && (response.op_profile().has_by_program() ||
+ response.op_profile().has_by_category())) {
TF_RETURN_IF_ERROR(DumpOpProfileToLogDirectory(profile_run_dir, host_prefix,
response.op_profile(), os));
}
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index feb177a7da..68cf510e71 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -4,12 +4,14 @@ package tensorflow.tpu.op_profile;
// Profile is the top-level data that summarizes a program.
message Profile {
+ reserved 2;
+ reserved "by_program_structure";
+ reserved 3;
+ reserved "per_program";
// Root of a profile broken down by instruction category.
Node by_category = 1;
- // Root of a profile broken down by program structure.
- Node by_program_structure = 2;
- // Per program profile, indexed by hlo module name of the program.
- map<string, Node> per_program = 3;
+ // Root of a profile broken down by program.
+ Node by_program = 4;
}
// An entry in the profile tree. (An instruction, or set of instructions).
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
index 438f442848..63641e00c5 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
@@ -116,12 +116,13 @@ def main(unused_argv=None):
elif tpu_cluster_resolver is not None:
workers_list = get_workers_list(tpu_cluster_resolver)
- if not FLAGS.logdir:
+ if not FLAGS.logdir and not FLAGS.monitoring_level:
sys.exit('logdir must be provided.')
executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE)
- logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
cmd = [executable_path]
- cmd.append('--logdir=' + logdir)
+ if FLAGS.logdir is not None:
+ logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir))
+ cmd.append('--logdir=' + logdir)
cmd.append('--service_addr=' + service_addr)
cmd.append('--workers_list=' + workers_list)
cmd.append('--duration_ms=' + str(FLAGS.duration_ms))
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index d4ccb0f246..2415c46718 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.10.0'
+_VERSION = '1.11.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index aee094177b..90d34b5ef1 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.10.0"
+#define TPU_PROFILER_VERSION "1.11.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/contrib/tpu/proto/BUILD
index 598b73b438..c20cab844c 100644
--- a/tensorflow/contrib/tpu/proto/BUILD
+++ b/tensorflow/contrib/tpu/proto/BUILD
@@ -10,12 +10,15 @@ load(
)
tf_proto_library(
- name = "tpu_embedding_config_proto",
+ name = "tpu_embedding_configuration_proto",
srcs = [
- "tpu_embedding_config.proto",
+ "tpu_embedding_configuration.proto",
],
cc_api_version = 2,
- protodeps = [":optimization_parameters_proto"],
+ protodeps = [
+ ":tpu_embedding_output_layout_proto",
+ ":optimization_parameters_proto",
+ ],
visibility = ["//visibility:public"],
)
@@ -29,6 +32,15 @@ tf_proto_library(
)
tf_proto_library(
+ name = "tpu_embedding_output_layout_proto",
+ srcs = [
+ "tpu_embedding_output_layout.proto",
+ ],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
name = "topology_proto",
srcs = [
"topology.proto",
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
deleted file mode 100644
index 3476cc8953..0000000000
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_config.proto
+++ /dev/null
@@ -1,66 +0,0 @@
-syntax = "proto3";
-
-package tensorflow.tpu;
-
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-
-// The TPUEmbeddingConfiguration contains specification of TPU Embedding lookups
-// and gradient updates separate from the TF Graph.
-message TPUEmbeddingConfiguration {
- // model_mode specifies whether the model is to be run in training or
- // inference. In inference mode, gradient updates to embedding tables are not
- // performed.
- enum ModelMode {
- INVALID = 0;
- TRAINING = 1;
- INFERENCE = 2;
- }
-
- ModelMode model_mode = 1;
-
- // num_hosts is the number of host CPU systems in the training/inference job.
- // Each embedding table must be sharded into num_hosts separate Variables,
- // placed separately on the num_hosts CPU devices in the cluster. Sharding
- // will be performed equivalently to the 'div' sharding_strategy option of
- // embedding_lookup() and embedding_lookup_sparse().
- int32 num_hosts = 2;
-
- // The total number of TensorNodes. This is equal to num_hosts times the
- // number of TensorNodes attached to each host.
- int32 num_tensornodes = 3;
-
- // The number of training examples per TensorNode.
- int32 batch_size = 4;
-
- // Each Embedding
- message TPUEmbeddingTable {
- // Name of the embedding table. This will be used to name Variables in the
- // Tensorflow Graph.
- string name = 1;
-
- // Number of rows of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 num_rows = 3;
-
- // Width of the embedding table. The Variable created to hold the
- // learned embedding table values will have shape (num_rows, width).
- int32 width = 4;
-
- // Number of distinct embedding activation vectors per training example
- // produced by lookups into this table during model evaluation. For each
- // table, the Graph will receive an activations Tensor of shape
- // (batch_size * table.num_features, table.width).
- // For example, num_features = 1 produces equivalent behavior to a single
- // tf.nn.embedding_lookup() call. In the case of 'multivalent' embeddings,
- // (i.e. tf.nn.embedding_lookup_sparse()) which compute weighted averages of
- // embedding table rows, num_features is the number of vectors produced
- // after averaging. In sequence models num_features is typically equal
- // to the sequence length, since each sequence element must be represented
- // separately to the convolutional or recurrent network.
- int32 num_features = 5;
-
- OptimizationParameters optimization_parameters = 6;
- }
-
- repeated TPUEmbeddingTable table_config = 5;
-}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
new file mode 100644
index 0000000000..da19b135d7
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
@@ -0,0 +1,95 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
+import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+
+message TPUEmbeddingConfiguration {
+ // Description of the various embedding tables.
+ message TableDescriptor {
+ // Name of the table.
+ string name = 1;
+ // Size of the vocabulary (i.e., number of rows) in the table.
+ int32 vocabulary_size = 2;
+ // The embedding dimension (i.e., the width of the embedding table).
+ int32 dimension = 3;
+ // Number of features mapped to this table.
+ int32 num_features = 4;
+ // Details of the learning algorithm used to update the embedding
+ // parameters.
+ OptimizationParameters optimization_parameters = 5;
+ }
+ repeated TableDescriptor table_descriptor = 1;
+
+ // Mode. Should the embedding layer program be run for inference (just forward
+ // pass), training (both forward and backward pass) or just the backward_pass.
+ enum Mode {
+ UNSPECIFIED = 0;
+ INFERENCE = 1;
+ TRAINING = 2;
+ BACKWARD_PASS_ONLY = 3;
+ }
+ Mode mode = 2;
+
+ // Number of samples in each batch of embedding layer activations sent to
+ // the TensorCore.
+ int32 batch_size_per_tensor_core = 3;
+
+ // Number of TPU hosts used for inference/training.
+ int32 num_hosts = 4;
+
+ // Number of TensorCore used for inference/training.
+ int32 num_tensor_cores = 5;
+
+ // Sharding strategy of the embedding tables among the hosts.
+ // If the sharding_strategy is "mod", each id is assigned to host
+ // "id % num_hosts". For instance, 13 ids are split across 5 hosts as:
+ // [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]].
+ // If the sharding_strategy is "div", ids are assigned to hosts in a
+ // contiguous manner. In this case, 13 ids are split across 5 hosts as:
+ // [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].
+ // In both the strategies, if the id space does not evenly divide the number
+ // of hosts, each of the first "table_descriptor.num_ids % num_hosts" hosts
+ // will be assigned one more id.
+ // This partitioning strategy exactly follows that in the embedding_lookup
+ // TensorFlow function at tensorflow/python/ops/embedding_ops.py.
+ enum ShardingStrategy {
+ DIV_DEFAULT = 0;
+ MOD = 1;
+ }
+ ShardingStrategy sharding_strategy = 6;
+
+ // This parameter determines if the execution of the sparse core will be
+ // pipelined with that of the TensorCore. This parameter only affects results
+ // when mode=TRAINING. If mode=INFERENCE or BACKWARD_PASS_ONLY, this parameter
+ // does not affect execution and hence, is a don't care value.
+ //
+ // false: The execution of the sparse core is not pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core is executed
+ // only after the backward pass of the previous step is complete. And the
+ // backward pass on the sparse core is executed only after the embedding
+ // gradients have been computed on the TensorCore on every step. This ensures
+ // that the activations on every step observe the gradient updates from the
+ // previous step on both the sparse core and the TensorCore.
+ //
+ // true: The execution of the sparse core is pipelined with that of the
+ // TensorCore. The forward pass of every step on the sparse core can be
+ // executed after the forward pass of the previous step is complete without
+ // waiting for the backward pass. This improves the utilization of the sparse
+ // core allowing it to process step N+1 while the embedding gradients for step
+ // N are computed on the TensorCore. The backward pass of every step on the
+ // sparse core is executed directly after the forward pass for the next step
+ // is complete. The drawback is that embedding activations for step N+1 do not
+ // observe the embedding gradient updates from step N. This could affect model
+ // quality if step N and N+1 involve the same set of embedding IDs. However,
+ // since the embedding updates are sparse, this is generally not considered a
+ // problem.
+ bool pipeline_execution_with_tensor_core = 7;
+
+ // Extended output layout information; if not provided, a compatibility mode
+ // will use defaults that match the old layout. Providing a value for this
+ // field is EXPERIMENTAL and most ways of filling it will probably break. Do
+ // not set it unless you know what you are doing.
+ TPUEmbeddingOutputLayout output_layout = 8;
+}
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
new file mode 100644
index 0000000000..aed30b2f22
--- /dev/null
+++ b/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
@@ -0,0 +1,75 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// In the comments here, "layout" refers to the top-level EmbeddingOutputLayout
+// proto contained in the TPUEmbeddingConfiguration.
+
+// The embedding output consists of a list of tensors, each specified by an
+// EmbeddingOutputTensor proto within the EmbeddingOutputLayout (the "output"
+// field). Each table and feature lookup is then placed into some number of
+// particular positions within some output tensor (identified by "tensor_index"
+// within OutputLocation). The tree of table lookups, feature lookups, and
+// output locations is specified by the
+// "table(table_id).feature(feature_id).output_location" repeated fields within
+// EmbeddingOutputLayout.
+
+message TPUEmbeddingOutputLayout {
+ // Location of one copy of the feature's data.
+ message OutputLocation {
+ // Which output tensor this copy of the feature will go into. Must be
+ // between 0 and layout.output_size().
+ int32 tensor_index = 1;
+
+ // Offset in dimension 0 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim0_size_per_sample().
+ int32 dim0_offset = 2;
+
+ // Offset in dimension 1 for this feature copy. Must be between 0 and
+ // layout.output(tensor_index).dim1_size() - table width; repeated or
+ // partially/fully overlapping values are allowed and results in the same
+ // range will be summed (with the gradients replicated in the backward
+ // pass).
+ int32 dim1_offset = 3;
+ }
+
+ // Description of the output placement for one feature.
+ message FeatureDescriptor {
+ // Typically, only one copy of each feature is used, but multiple are
+ // allowed and the same data will be copied to all of them (with the
+ // gradients summed in the backward pass).
+ repeated OutputLocation output_location = 1;
+ }
+
+ // Description of the output placement for features of one table.
+ message TableDescriptor {
+ // Output locations for each feature loaded from this table.
+ repeated FeatureDescriptor feature = 1;
+ }
+ // Output locations for each feature of each table.
+ repeated TableDescriptor table = 1;
+
+ // Data layout and shape computation information for a single output tensor.
+ // Any unused locations in the tensor will be filled with zeros, and
+ // corresponding gradients will be ignored.
+
+ // Size and layout information for 2-D tensors.
+ message TwoDOutputTensor {
+ // Multiplier for output dimension 0 size; used to match legacy format that
+ // stacks features within a sample in dimension 0.
+ int32 dim0_size_per_sample = 2;
+
+ // The size (in dimension 1) of this output tensor.
+ int32 dim1_size = 1;
+ }
+
+ // Format information for a single output tensor.
+ message EmbeddingOutputTensor {
+ oneof output_format {
+ TwoDOutputTensor two_d = 4;
+ }
+ }
+
+ // Shape and layout information for each tensor.
+ repeated EmbeddingOutputTensor output = 2;
+}
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index d92a0652bb..a1aee69691 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -95,7 +95,7 @@ if platform.system() != "Windows":
]
def cross_replica_sum(x, group_assignment=None, name=None):
- """Sum the input tensor accorss replicas according to group_assignment.
+ """Sum the input tensor across replicas according to group_assignment.
Args:
x: The local tensor to the sum.
@@ -112,6 +112,31 @@ if platform.system() != "Windows":
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+ def collective_permute(x, source_target_pairs, name=None):
+ """Permute the input tensor across replicas given source_target_pairs.
+
+ For each source_target_pair <a, b>, we send replica a's input to replica b.
+ Each replica id must only appear once in the source column. Also it must
+ only appear once in the target column.
+ For the replica id not in the target column, this op returns a zero tensor
+ with the same shape and dtype of the input x.
+
+ For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+ source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs:
+ `[0, A, B, C]`.
+
+ Args:
+ x: The local tensor to be permuted.
+ source_target_pairs: 2d int lists with shape [num_pairs, 2].
+ source_target_pairs[i][0] represents the source replica id and
+ source_target_pairs[i][1] represents the target replica id.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is permuted.
+ """
+ return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
diff --git a/tensorflow/contrib/tpu/python/tpu/device_assignment.py b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
index 471b1fa46c..b9e2a4287a 100644
--- a/tensorflow/contrib/tpu/python/tpu/device_assignment.py
+++ b/tensorflow/contrib/tpu/python/tpu/device_assignment.py
@@ -72,13 +72,12 @@ class DeviceAssignment(object):
self._invert_topology(topology))
topology_rank = self._topology_tasks.ndim
- if core_assignment.ndim != topology_rank + 2:
- raise ValueError("core_assignment must be a rank {} numpy array".format(
- topology_rank + 2))
+ if core_assignment.ndim != 3:
+ raise ValueError("core_assignment must be a rank 3 numpy array, "
+ "got shape {}".format(core_assignment.shape))
self._num_replicas = core_assignment.shape[0]
- self._computation_shape = np.array(
- core_assignment.shape[1:-1], dtype=np.int32)
+ self._num_cores_per_replica = core_assignment.shape[1]
if core_assignment.shape[-1] != topology_rank:
raise ValueError(
@@ -107,18 +106,15 @@ class DeviceAssignment(object):
"""Computes a nested dict which maps task and logical core to replicas."""
task_and_cores_to_replicas = {}
for replica in xrange(core_assignment.shape[0]):
- for dx in xrange(core_assignment.shape[1]):
- for dy in xrange(core_assignment.shape[2]):
- for dz in xrange(core_assignment.shape[3]):
- x, y, z = core_assignment[replica, dx, dy, dz, :]
- task_id = topology_tasks[x, y, z]
- if task_id not in task_and_cores_to_replicas:
- task_and_cores_to_replicas[task_id] = {}
- logical_core = (dx, dy, dz)
- if logical_core not in task_and_cores_to_replicas[task_id]:
- task_and_cores_to_replicas[task_id][logical_core] = set()
-
- task_and_cores_to_replicas[task_id][logical_core].add(replica)
+ for logical_core in xrange(core_assignment.shape[1]):
+ x, y, z = core_assignment[replica, logical_core, :]
+ task_id = topology_tasks[x, y, z]
+ if task_id not in task_and_cores_to_replicas:
+ task_and_cores_to_replicas[task_id] = {}
+ if logical_core not in task_and_cores_to_replicas[task_id]:
+ task_and_cores_to_replicas[task_id][logical_core] = set()
+
+ task_and_cores_to_replicas[task_id][logical_core].add(replica)
task_to_sorted_replica_id = {}
@@ -136,23 +132,9 @@ class DeviceAssignment(object):
return self._topology
@property
- def computation_shape(self):
- """The computation shape.
-
- Returns:
- A rank-1 int32 numpy array with size equal to the TPU topology rank.
- Describes the logical shape in numbers of core of each replica of the
- computation in the TPU topology.
-
- Returns:
- The computation shape.
- """
- return self._computation_shape
-
- @property
def num_cores_per_replica(self):
"""The number of cores per replica."""
- return np.prod(self.computation_shape)
+ return self._num_cores_per_replica
@property
def num_replicas(self):
@@ -164,33 +146,22 @@ class DeviceAssignment(object):
"""The logical to physical core mapping.
Returns:
- A numpy array of rank `topology_rank + 2`, with shape
- `[num_replicas] + computation_shape + [topology_rank]`. Maps
- (replica, logical core coordinates) pairs to physical topology
- coordinates.
+ An integer numpy array of rank 3, with shape
+ `[num_replicas, num_cores_per_replica, topology_rank]`. Maps
+ (replica, logical core) pairs to physical topology coordinates.
"""
return self._core_assignment
def _coordinates(self, replica, logical_core):
"""Returns the physical topology coordinates of a logical core."""
- if logical_core is None:
- logical_core = np.array([0, 0, 0], np.int32)
- else:
- logical_core = np.asarray(logical_core)
-
- if any(logical_core < 0) or any(logical_core >= self.computation_shape):
- raise ValueError("Invalid core {}; computation shape is {}".format(
- logical_core, self.computation_shape))
-
- logical_offset = tuple([replica] + logical_core.tolist() + [slice(3)])
- return tuple(self.core_assignment[logical_offset])
+ return tuple(self.core_assignment[replica, logical_core, :])
def lookup_replicas(self, task_id, logical_core):
"""Lookup replica ids by task number and logical core.
Args:
task_id: TensorFlow task number.
- logical_core: A tuple of three integers which represents a logical core.
+ logical_core: An integer, identifying a logical core.
Returns:
A sorted list of the replicas that are attached to that task and
logical_core.
@@ -205,17 +176,17 @@ class DeviceAssignment(object):
"Can not find any replica in task: {} contains logical_core: {} ".
format(task_id, logical_core))
- def tpu_ordinal(self, replica=0, logical_core=None):
+ def tpu_ordinal(self, replica=0, logical_core=0):
"""Returns the ordinal of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return self._topology_devices[coordinates]
- def host_device(self, replica=0, logical_core=None, job=None):
+ def host_device(self, replica=0, logical_core=0, job=None):
"""Returns the CPU device attached to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_host_device_name(job, self._topology_tasks[coordinates])
- def tpu_device(self, replica=0, logical_core=None, job=None):
+ def tpu_device(self, replica=0, logical_core=0, job=None):
"""Returns the name of the TPU device assigned to a logical core."""
coordinates = self._coordinates(replica, logical_core)
return _tpu_device_name(job, self._topology_tasks[coordinates],
@@ -228,6 +199,8 @@ def device_assignment(topology,
num_replicas=1):
"""Computes a device_assignment of a computation across a TPU topology.
+ Attempts to choose a compact grid of cores for locality.
+
Returns a `DeviceAssignment` that describes the cores in the topology assigned
to each core of each replica.
@@ -240,12 +213,12 @@ def device_assignment(topology,
`initialize_system` using `Session.run`. Either a serialized
`TopologyProto` or a `Topology` object may be passed. Note: you must
evaluate the `Tensor` first; you cannot pass an unevaluated `Tensor` here.
- computation_shape: A rank 1 int32 numpy array of size 3, describing the
- shape of the computation's block of cores. If None, the
- `computation_shape` is `[1, 1, 1]`.
- computation_stride: A rank 1 int32 numpy array of size 3, describing the
- inter-core spacing of the `computation_shape` cores in the TPU topology.
- If None, the `computation_stride` is `[1, 1, 1]`.
+ computation_shape: A rank 1 int32 numpy array with size equal to the
+ topology rank, describing the shape of the computation's block of cores.
+ If None, the `computation_shape` is `[1] * topology_rank`.
+ computation_stride: A rank 1 int32 numpy array of size `topology_rank`,
+ describing the inter-core spacing of the `computation_shape` cores in the
+ TPU topology. If None, the `computation_stride` is `[1] * topology_rank`.
num_replicas: The number of computation replicas to run. The replicas will
be packed into the free spaces of the topology.
@@ -271,21 +244,21 @@ def device_assignment(topology,
topology_rank = len(topology.mesh_shape)
mesh_shape = topology.mesh_shape
if computation_shape is None:
- computation_shape = np.array([1, 1, 1], dtype=np.int32)
+ computation_shape = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_shape = np.asarray(computation_shape, dtype=np.int32)
if computation_stride is None:
- computation_stride = np.array([1, 1, 1], dtype=np.int32)
+ computation_stride = np.array([1] * topology_rank, dtype=np.int32)
else:
computation_stride = np.asarray(computation_stride, dtype=np.int32)
- if computation_shape.shape != (3,):
- raise ValueError("computation_shape must have shape [3]; got {}".format(
- computation_shape.shape))
- if computation_stride.shape != (3,):
- raise ValueError("computation_stride must have shape [3]; got {}".format(
- computation_stride.shape))
+ if computation_shape.shape != (topology_rank,):
+ raise ValueError("computation_shape must have shape [{}]; got {}".format(
+ topology_rank, computation_shape.shape))
+ if computation_stride.shape != (topology_rank,):
+ raise ValueError("computation_stride must have shape [{}]; got {}".format(
+ topology_rank, computation_stride.shape))
if any(computation_shape < 1):
raise ValueError(
@@ -315,28 +288,41 @@ def device_assignment(topology,
num_replicas, max_replicas, computation_shape, computation_stride,
mesh_shape))
- # Choose a compact layout for the cores. Choose the smaller dimension in the
- # topology to be close to the square root of the number of replicas.
- num_chips = int(math.ceil(num_replicas / replica_counts[2]))
- target_size = int(math.ceil(math.sqrt(num_chips)))
-
- # Prefer an even size, if possible. Odd numbered rows head back towards the
- # first column, so it's best if the last row has an odd index.
- if target_size % 2 != 0:
- target_size -= 1
- y_size = min(replica_counts[1], target_size)
- if y_size * replica_counts[0] < num_chips:
- y_size = replica_counts[1]
+ def ceil_of_ratio(n, m):
+ return (n + m - 1) // m
+
+ replica_shape = [0] * topology_rank
+ if num_replicas > 0:
+ remaining_replicas = num_replicas
+ remaining_dims = topology_rank
+
+ # Choose dimensions as close to an equal cube as possible, in order of
+ # increasing dimension size. By visiting dimensions in increasing size, we
+ # assign the most constrained dimension first, so we won't make infeasible
+ # choices.
+ #
+ # As a secondary sort order, visit the dimensions in reverse order. This
+ # means we try to use both cores on the same chip in preference to two cores
+ # on different chips.
+ for x, ni in sorted(((x, -i) for (i, x) in enumerate(replica_counts))):
+ i = -ni
+ target_size = int(math.ceil(remaining_replicas**(1.0 / remaining_dims)))
+ replica_shape[i] = min(target_size, x)
+ remaining_replicas = ceil_of_ratio(remaining_replicas, replica_shape[i])
+ remaining_dims -= 1
+
+ assert remaining_replicas == 1 and remaining_dims == 0
# Assigns an offset to each replica such that no two replicas overlap.
- replica_offsets = np.full([num_replicas, 3], -1, dtype=np.int32)
+ replica_offsets = np.full([num_replicas, topology_rank], -1, dtype=np.int32)
for replica in xrange(num_replicas):
- # Chooses a replica number in X/Y/Z axes.
- z = replica % replica_counts[2]
- t = replica // replica_counts[2]
- y = t % y_size
- x = t // y_size
- replica_pos = np.array([x, y, z], dtype=np.int32)
+ # Chooses a replica number in each axis.
+ t = replica
+ pos = []
+ for dim in replica_shape[::-1]:
+ pos.append(t % dim)
+ t //= dim
+ replica_pos = np.array(pos[::-1], dtype=np.int32)
# Determines where that replica starts in each axis.
outer = replica_pos // computation_stride
@@ -351,6 +337,6 @@ def device_assignment(topology,
indices = np.concatenate(
[i[..., np.newaxis] for i in np.meshgrid(*indices, indexing="ij")],
axis=-1)
- assignment = (
- indices + replica_offsets[:, np.newaxis, np.newaxis, np.newaxis, :])
+ indices = indices.reshape((-1, topology_rank))
+ assignment = indices + replica_offsets[:, np.newaxis, :]
return DeviceAssignment(topology, core_assignment=assignment)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index d8c3872363..f67e0e6aca 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -25,10 +25,9 @@ flattened = tf.keras.layers.Flatten()(c1)
logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
model = tf.keras.Model(inputs=[image], outputs=[logits])
-strategy = keras_support.TPUDistributionStrategy(num_cores_per_host=8)
-model = keras_support.tpu_model(model,
- strategy=strategy,
- tpu_name_or_address=tpu_name)
+resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
+strategy = keras_support.TPUDistributionStrategy(resolver)
+model = keras_support.tpu_model(model, strategy=strategy)
# Only TF optimizers are currently supported.
model.compile(optimizer=tf.train.AdamOptimizer(), ...)
@@ -47,12 +46,12 @@ from __future__ import print_function
import abc
import collections
-import contextlib
import re
import sys
import time
import numpy as np
+import six
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
@@ -69,6 +68,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -76,6 +76,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
@@ -89,34 +90,34 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
-_SESSIONS = {}
-
-
-def tpu_session(cluster_resolver):
+def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
- global _SESSIONS
master = cluster_resolver.master()
- if master not in _SESSIONS:
- cluster_spec = cluster_resolver.cluster_spec()
- config = config_pb2.ConfigProto(isolate_session_state=True)
- if cluster_spec:
- config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Connecting to: %s', master)
- graph = ops.Graph()
- session = tf_session.Session(graph=graph, target=master, config=config)
- with graph.as_default():
- session.run(tpu.initialize_system())
+ # Use the existing session if we're already connected to this TPU
+ if (K.get_session()._target == master and
+ getattr(K.get_session(), '_tpu_initialized', None)):
+ return
+
+ cluster_spec = cluster_resolver.cluster_spec()
+ config = config_pb2.ConfigProto(isolate_session_state=True)
+ if cluster_spec:
+ config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- _SESSIONS[master] = session
- return _SESSIONS[master]
+ logging.info('Initialize')
+ tpu_session = tf_session.Session(target=master, config=config)
+ tpu_session.run(tpu.initialize_system())
+ tpu_session._tpu_initialized = True
+ # N.B. We have to call `K.set_session()` AND set our session as the
+ # TF default. `K.get_session()` surprisingly does not return the value
+ # supplied by K.set_session otherwise.
+ K.set_session(tpu_session)
-def reset_tpu_sessions():
- _SESSIONS.clear()
try:
from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
@@ -133,9 +134,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
- master,
- cluster_def=cluster_def,
- query_topology=False))
+ master, cluster_def=cluster_def, query_topology=False))
return tpu_system_metadata
@@ -156,6 +155,8 @@ class TPUDistributionStrategy(object):
replication, typically using all avaiable TPU cores. If overwrites as
`True`, force the model replication using single core, i.e., no
replication.
+ Raises:
+ Exception: No TPU Found on the given worker.
"""
if tpu_cluster_resolver is None:
@@ -171,7 +172,8 @@ class TPUDistributionStrategy(object):
for device in metadata.devices:
if 'TPU:0' in device.name:
self._worker_name = worker_re.search(device.name).group(1)
- break
+ return
+ raise Exception('No TPU found on given worker.')
def _make_assignment_for_model(self, cpu_model):
"""Makes a `TPUAssignment` for the passed in `cpu_model`."""
@@ -182,8 +184,7 @@ class TPUDistributionStrategy(object):
'Degrading to a single core.')
num_cores = 1
- return TPUAssignment(
- worker_name=self._worker_name, num_cores=num_cores)
+ return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores)
class TPUAssignment(object):
@@ -229,6 +230,39 @@ class TPUEmbedding(embeddings.Embedding):
return math_ops.tensordot(inputs, self.embeddings, 1)
+def _cross_replica_concat(tensor, core_id, num_cores, name):
+ """Concatenate `tensor` across cores.
+
+ Args:
+ tensor: The tensor to be concatenated. Must be [int32 and float32].
+ core_id: Tensor indicating the current TPU core.
+ num_cores: Python int. The total number of TPU cores in the system.
+ name: The string name to print for debugging.
+
+ Returns:
+ The same concatenated Tensor on each core.
+ """
+
+ input_dtype = tensor.dtype
+ if input_dtype not in [dtypes.float32, dtypes.int32]:
+ raise TypeError('For model replication, only (float32 and int32) is '
+ 'supported for model outputs and targets. Got {} for '
+ '{}.'.format(input_dtype, name))
+
+ batch_size = tensor.shape[0]
+ mask = math_ops.to_float(math_ops.equal(range(num_cores), core_id))
+ mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
+ result = mask * math_ops.to_float(tensor)
+ local_tensor_with_holes = array_ops.reshape(result,
+ [-1] + result.shape.as_list()[2:])
+ concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
+ concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))
+
+ if concat_tensor != input_dtype:
+ concat_tensor = math_ops.cast(concat_tensor, input_dtype)
+ return concat_tensor
+
+
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
"""An optimizer that averages gradients across TPU shards."""
@@ -246,9 +280,9 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
super(KerasCrossShardOptimizer, self).__init__()
self._name = name
self._opt = opt
+ logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)
def get_updates(self, loss, params):
- logging.info('Get updates: %s', loss)
self._opt.get_gradients = self.get_gradients
return self._opt.get_updates(loss, params)
@@ -257,17 +291,15 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
- def set_weights(self, weights):
- # TODO(power): Figure out whether we really need this given there is no
- # caller for this API yet.
- self._opt.set_weights()
-
def get_weights(self):
return self._opt.get_weights()
- @property
- def lr(self):
- return self._opt.lr
+ def get_config(self):
+ return self._opt.get_config()
+
+ # Defer remaining operations to the underlying optimizer
+ def __getattr__(self, key):
+ return getattr(self._opt, key)
class TPUModelOp(
@@ -293,6 +325,36 @@ def _replicated_optimizer(opt):
return KerasCrossShardOptimizer(opt)
+def _clone_metrics(metrics):
+ """Returns a copy of metrics. A copy is created for stateful metrics."""
+ if metrics is None:
+ return None
+ with variable_scope.variable_scope(
+ 'metrics', reuse=variable_scope.AUTO_REUSE):
+ return [
+ m.__class__.from_config(m.get_config()) if isinstance(
+ m, metrics_module.Metric) else m for m in metrics
+ ]
+
+
+def _clone_optimizer(optimizer, config=None):
+ """Returns a cloned optimizer with the provided optimizer.config or config."""
+ if not isinstance(optimizer, keras_optimizers.Optimizer):
+ # In the first call to tpu_model(model), Keras may not have wrapped the TF
+ # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
+ # or optimizer isn't set, and later generated tpu_model compiles with a TF
+ # optimizer.
+ return optimizer
+
+ if isinstance(optimizer, keras_optimizers.TFOptimizer):
+ return keras_optimizers.TFOptimizer(optimizer.optimizer)
+
+ if config is None:
+ config = optimizer.get_config()
+ logging.info('Cloning %s %s', optimizer.__class__.__name__, config)
+ return optimizer.__class__.from_config(config)
+
+
class TPURewriteContext(object):
"""Prepare the environment for a Keras model during `tpu.rewrite`.
@@ -381,6 +443,7 @@ class TPURewriteContext(object):
return (r, q)
else:
raise ValueError('Invalid shape passed to qr: %s' % input_shape)
+
gen_linalg_ops.qr = qr
ops.name_scope = _name_scope
@@ -396,9 +459,9 @@ class TPURewriteContext(object):
gen_linalg_ops.qr = self._default_qr
-class SizedInfeed(collections.namedtuple('SizedInfeed',
- ['sharded_infeed_tensors',
- 'infeed_ops'])):
+class SizedInfeed(
+ collections.namedtuple('SizedInfeed',
+ ['sharded_infeed_tensors', 'infeed_ops'])):
"""Represents an instantiation of the infeed ops for a concrete input shape.
sharded_infeed_tensors: A data structure of Tensors used to represent the
@@ -584,12 +647,13 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_tensors, [spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_op,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors)
class TPUDatasetInfeedManager(TPUInfeedManager):
"""Manages infeed for a `tf.data.Dataset` into a TPU computation.
+
"""
class DatasetInfeedInstance(TPUInfeedInstance):
@@ -607,18 +671,17 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
return {}
# pylint: disable=redefined-outer-name
- def __init__(self, dataset, tpu_assignment, tpu_session):
+ def __init__(self, dataset, tpu_assignment, mode):
"""Constructs a TPUDatasetInfeedManager.
- Must be called within a `KerasTPUModel.tpu_session` context!
-
Args:
dataset: A `tf.data.Dataset` to infeed.
tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
- tpu_session: The `tf.Session` object used for running the TPU model.
+ mode: ModeKeys enum.
"""
self._verify_dataset_shape(dataset)
+
self._dataset = dataset
self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
@@ -626,7 +689,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
dummy_y_shape = dataset.output_shapes[1].as_list()
dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
- tpu_session.run(self._iterator.initializer)
+ K.get_session().run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
@@ -639,10 +702,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- self._dummy_x = np.zeros(dummy_x_shape,
- dtype=dataset.output_types[0].as_numpy_dtype)
- self._dummy_y = np.zeros(dummy_y_shape,
- dtype=dataset.output_types[1].as_numpy_dtype)
+ self._dummy_x = np.zeros(
+ dummy_x_shape, dtype=dataset.output_types[0].as_numpy_dtype)
+ self._dummy_y = np.zeros(
+ dummy_y_shape, dtype=dataset.output_types[1].as_numpy_dtype)
input_specs = []
if isinstance(self._iterator.output_shapes, tuple):
@@ -658,6 +721,10 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
self._iterator.output_types)
input_specs.append(spec)
+ # Pre-process the inputs and get_next_ops before caching.
+ input_specs, self._get_next_ops = (
+ _inject_tpu_inputs_for_dataset(
+ tpu_assignment, mode, input_specs, self._get_next_ops))
self._infeed_instance = self.DatasetInfeedInstance(input_specs)
def _verify_dataset_shape(self, dataset):
@@ -669,9 +736,8 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
raise ValueError('The dataset must return a tuple of tf.Tensors, '
'instead it returns: %s' % dataset.output_classes)
if len(dataset.output_classes) != 2:
- raise ValueError(
- 'The dataset must return a 2-element tuple, got '
- '%s output classes instead.' % (dataset.output_classes,))
+ raise ValueError('The dataset must return a 2-element tuple, got '
+ '%s output classes instead.' % (dataset.output_classes,))
for i, cls in enumerate(dataset.output_classes):
if cls != ops.Tensor:
raise ValueError('The dataset returned a non-Tensor type (%s) at '
@@ -680,8 +746,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
if not shape:
raise ValueError('The dataset returns a scalar tensor in '
'tuple index %d. Did you forget to batch? '
- '(Output shapes: %s).' % (i,
- dataset.output_shapes))
+ '(Output shapes: %s).' % (i, dataset.output_shapes))
for j, dim in enumerate(shape):
if dim.value is None:
if j == 0:
@@ -721,8 +786,73 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
[spec.shape for spec in input_specs],
name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
device_ordinal=shard_id))
- return SizedInfeed(infeed_ops=infeed_ops,
- sharded_infeed_tensors=shard_infeed_tensors)
+ return SizedInfeed(
+ infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
+
+
+def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
+ input_specs, get_next_ops):
+ """Append core information to the set of dataset inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_specs, get_next_ops
+
+ # Dataset inputs operate on per core basis.
+ per_core_batch_size = input_specs[0].shape.as_list()[0]
+
+ # Insert, at head, the tensor for core_id.
+ assert len(get_next_ops) == tpu_assignment.num_towers
+ for i in range(tpu_assignment.num_towers):
+ core_id_constant = constant_op.constant(
+ np.array([i] * per_core_batch_size).astype('int32'),
+ dtype=dtypes.int32,
+ name='cord_id_constant')
+ get_next_ops[i] = [core_id_constant] + list(get_next_ops[i])
+
+ # Insert the input spec at head also.
+ input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32)
+ ] + input_specs
+
+ return input_specs, get_next_ops
+
+
+def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, input_tensors, inputs):
+ """Append core information to the set of inputs."""
+ # This is used during compilation to identify the current TPU core and enable
+ # concatenation operations across cores.
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return input_tensors, inputs
+
+ # Puts a place holder in input spec.
+ core_id_place_holder = array_ops.placeholder(
+ dtype=dtypes.int32, shape=[1], name='core_id')
+ input_tensors = [core_id_place_holder] + input_tensors
+
+ # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
+ # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id
+ # (duplicated).
+ num_cores = tpu_assignment.num_towers
+ per_core_batch_size = inputs[0].shape[0] // num_cores
+ core_ids = np.arange(num_cores).repeat(per_core_batch_size)
+ inputs = [core_ids] + inputs
+ return input_tensors, inputs
+
+
+def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
+ """Popping out the core ids from infeed."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
+ return None, infeed_tensors
+
+ if len(infeed_tensors) <= 1:
+ raise RuntimeError(
+ 'The infeed tensors on TPU core has only {} tensors. '
+ 'This is not expected. Please report a bug.\nTensors: {}'.format(
+ len(infeed_tensors), infeed_tensors))
+
+ core_id = infeed_tensors[0][0] # Pop out the scalar version.
+ rest = infeed_tensors[1:]
+ return core_id, rest
class TPUFunction(object):
@@ -743,12 +873,7 @@ class TPUFunction(object):
self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
-
- # Copy optimizer configuration. This is done prior to `_specialize_model`
- # as the configuration may require evaluating variables in the CPU session.
- self._optimizer_config = None
- if not isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- self._optimizer_config = self.model.optimizer.get_config()
+ self._cloned_optimizer = None
def _specialize_model(self, input_specs, infeed_manager):
"""Specialize `self.model` (a Keras model) for the given input shapes."""
@@ -775,6 +900,10 @@ class TPUFunction(object):
shapes=[spec.shape for spec in input_specs],
name='infeed-%s' % self.execution_mode)
+ core_id, infeed_tensors = (
+ _read_tpu_coreid_from_infeed(
+ mode=self.execution_mode, infeed_tensors=infeed_tensors))
+
assert len(infeed_tensors) == len(infeed_layers), (
'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
infeed_tensors))
@@ -790,31 +919,51 @@ class TPUFunction(object):
tpu_targets.append(tensor)
# Clone our CPU model, running within the TPU device context.
+ #
+ # We use the id of the original model as a key to avoid weight collisions
+ # (if a user re-runs the same model multiple times, in e.g. Colab).
with TPURewriteContext(tpu_input_map):
- with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
+ with variable_scope.variable_scope('tpu_%s' % id(self.model)):
with keras_tpu_variables.replicated_scope(
self._tpu_assignment.num_towers):
- self._cloned_model = models.clone_model(self.model)
+ if not self._cloned_optimizer:
+ self._cloned_optimizer = _clone_optimizer(
+ self.model.cpu_optimizer)
- # Create a copy of the optimizer for this graph.
- if isinstance(self.model.optimizer, keras_optimizers.TFOptimizer):
- cloned_optimizer = keras_optimizers.TFOptimizer(
- self.model.optimizer.optimizer)
- else:
- logging.info('Cloning %s %s', self.model.optimizer.__class__.__name__,
- self._optimizer_config)
- cloned_optimizer = self.model.optimizer.__class__.from_config(
- self._optimizer_config)
+ self._cloned_model = models.clone_model(self.model)
- if is_training or is_test:
- self._cloned_model.compile(
- optimizer=_replicated_optimizer(cloned_optimizer),
- loss=self.model.loss,
- loss_weights=self.model.loss_weights,
- metrics=self.model.metrics,
- weighted_metrics=self.model.weighted_metrics,
- target_tensors=tpu_targets,
- )
+ # When running on more than one core, concatenate outputs at the end
+ # of processing. In backprop stage, the gradients will be
+ # calculdated according to the local inputs as gradient of
+ # cross-replica-concat being zero for any outputs other than those
+ # from mlocal core so the loss calculation is identical.
+ num_towers = self.model._tpu_assignment.num_towers
+ if num_towers > 1 and (is_training or is_test):
+ new_outputs = [
+ _cross_replica_concat(
+ o, core_id, num_towers,
+ name='model output ({})'.format(o.name))
+ for o in self._cloned_model.outputs
+ ]
+ self._cloned_model.outputs = new_outputs
+ tpu_targets = [
+ _cross_replica_concat(
+ tensor,
+ core_id,
+ num_towers,
+ name='model target ({})'.format(tensor.name))
+ for tensor in tpu_targets
+ ]
+
+ if is_training or is_test:
+ self._cloned_model.compile(
+ optimizer=_replicated_optimizer(self._cloned_optimizer),
+ loss=self.model.loss,
+ loss_weights=self.model.loss_weights,
+ metrics=_clone_metrics(self.model.metrics),
+ weighted_metrics=_clone_metrics(self.model.weighted_metrics),
+ target_tensors=tpu_targets,
+ )
# Compute our outfeed depending on the execution mode
if is_training:
@@ -923,6 +1072,7 @@ class TPUFunction(object):
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
return mgr
+
return TPUNumpyInfeedManager(self.model._tpu_assignment)
def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
@@ -947,13 +1097,14 @@ class TPUFunction(object):
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
if shape_key not in self._compilation_cache:
- with self.model.tpu_session():
- logging.info('New input shapes; (re-)compiling: mode=%s, %s',
- self.execution_mode, input_specs)
- new_tpu_model_ops = self._specialize_model(input_specs,
- infeed_manager)
- self._compilation_cache[shape_key] = new_tpu_model_ops
- self._test_model_compiles(new_tpu_model_ops)
+ logging.info(
+ 'New input shapes; (re-)compiling: mode=%s '
+ '(# of cores %d), %s', self.execution_mode,
+ self._tpu_assignment.num_towers, input_specs)
+ new_tpu_model_ops = self._specialize_model(input_specs,
+ infeed_manager)
+ self._compilation_cache[shape_key] = new_tpu_model_ops
+ self._test_model_compiles(new_tpu_model_ops)
return self._compilation_cache[shape_key]
@@ -970,15 +1121,28 @@ class TPUFunction(object):
# Note: this condition is possible during the prologue or epilogue of the
# pipelined loop.
return None, None
- # Strip sample weight from inputs
+
+ if (self.model.uses_learning_phase and
+ not isinstance(K.learning_phase(), int)):
+ # Remove the learning_phase flag at the end. We currently hard code the
+ # learning_phase in TPUFunction.
+ assert isinstance(inputs[-1], int), (
+ 'Expect the final element be learning_phase flag. Got {}'.format(
+ inputs[-1]))
+ inputs = inputs[:-1]
+
if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ # Strip sample weight from inputs.
input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- return input_tensors, inputs
else:
input_tensors = self.model._feed_inputs
- return input_tensors, inputs
+
+ inputs = inputs[:len(input_tensors)]
+ input_tensors, inputs = (
+ _inject_tpu_inputs_for_infeed(
+ self._tpu_assignment, self.execution_mode, input_tensors, inputs))
+ return input_tensors, inputs
def _process_outputs(self, outfeed_outputs):
"""Processes the outputs of a model function execution.
@@ -1038,11 +1202,10 @@ class TPUFunction(object):
# Initialize our TPU weights on the first compile.
self.model._initialize_weights(self._cloned_model)
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
def pipeline_run(self, cur_step_inputs, next_step_inputs):
@@ -1074,8 +1237,8 @@ class TPUFunction(object):
next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
- if (next_step_infeed_manager is not None
- and cur_step_infeed_manager is not None):
+ if (next_step_infeed_manager is not None and
+ cur_step_infeed_manager is not None):
assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
next_input_tensors, next_step_inputs = (
@@ -1100,14 +1263,12 @@ class TPUFunction(object):
infeed_dict = None
if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
- cur_input_specs = cur_infeed_instance.make_input_specs(
- cur_input_tensors)
+ cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
cur_input_specs, cur_step_infeed_manager)
- if (next_infeed_instance
- and next_input_tensors
- and next_step_infeed_manager):
+ if (next_infeed_instance and next_input_tensors and
+ next_step_infeed_manager):
next_input_specs = next_infeed_instance.make_input_specs(
next_input_tensors)
next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
@@ -1118,26 +1279,24 @@ class TPUFunction(object):
self.model._initialize_weights(self._cloned_model)
if next_tpu_model_ops and cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
- cur_tpu_model_ops.outfeed_op
- ], infeed_dict)
+ _, _, outfeed_outputs = K.get_session().run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
return self._process_outputs(outfeed_outputs)
+
if cur_tpu_model_ops:
- with self.model.tpu_session() as session:
- _, outfeed_outputs = session.run([
- cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ _, outfeed_outputs = K.get_session().run(
+ [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
return self._process_outputs(outfeed_outputs)
+
if next_tpu_model_ops:
- with self.model.tpu_session() as session:
- session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
return None
raise RuntimeError('Internal error: both current & next tpu_model_ops '
'were None')
-
class KerasTPUModel(models.Model):
"""TPU compatible Keras model wrapper."""
@@ -1164,8 +1323,6 @@ class KerasTPUModel(models.Model):
self._tpu_model = None
self._tpu_weights_initialized = False
- self._session = tpu_session(cluster_resolver)
-
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
if self._cpu_model.optimizer:
@@ -1202,15 +1359,20 @@ class KerasTPUModel(models.Model):
if target_tensors:
raise ValueError('target_tensors is not supported for TPU execution.')
+ self._cpu_model.compile(
+ _clone_optimizer(optimizer),
+ loss,
+ _clone_metrics(metrics),
+ loss_weights,
+ sample_weight_mode,
+ _clone_metrics(weighted_metrics),
+ target_tensors,
+ **kwargs)
+
super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
sample_weight_mode, weighted_metrics,
target_tensors, **kwargs)
- if not self._cpu_model.optimizer:
- self._cpu_model.compile(optimizer, loss, metrics, loss_weights,
- sample_weight_mode, weighted_metrics,
- target_tensors, **kwargs)
-
def fit(self,
x=None,
y=None,
@@ -1243,8 +1405,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess,\
- ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ with ops.device('/job:%s/device:CPU:0' %
+ self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -1252,8 +1414,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1269,26 +1431,24 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(validation_data):
- with self.tpu_session() as sess:
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
self._numpy_to_infeed_manager_list = infeed_managers
try:
if not kwargs.get('_pipeline', True):
- logging.info(
- 'Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
kwargs.pop('_pipeline')
return super(KerasTPUModel, self).fit(
x,
@@ -1344,50 +1504,32 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
- sess)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).evaluate(
- x,
- y,
- batch_size,
- verbose,
- sample_weight,
- steps)
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
finally:
self._numpy_to_infeed_manager_list = []
- def _pipeline_fit(self,
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs):
+ def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle, class_weight,
+ sample_weight, initial_epoch, steps_per_epoch,
+ validation_steps, **kwargs):
# Similar to super.fit(...), but modified to support software pipelining.
# Backwards compatibility
@@ -1415,13 +1557,8 @@ class KerasTPUModel(models.Model):
# Prepare validation data
val_x, val_y, val_sample_weights = self._prepare_validation_data(
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
- batch_size)
+ validation_data, validation_split, validation_steps, x, y,
+ sample_weights, batch_size)
return self._pipeline_fit_loop(
x,
y,
@@ -1594,8 +1731,8 @@ class KerasTPUModel(models.Model):
for i in indices_for_conversion_to_dense:
ins_batch[i] = ins_batch[i].toarray()
- outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
- next_step_inputs=ins_batch)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
ins_last_batch = ins_batch
if batch_index == 0:
@@ -1667,8 +1804,8 @@ class KerasTPUModel(models.Model):
next_step_inputs = ins
else:
next_step_inputs = None
- outs = f.pipeline_run(cur_step_inputs=ins,
- next_step_inputs=next_step_inputs)
+ outs = f.pipeline_run(
+ cur_step_inputs=ins, next_step_inputs=next_step_inputs)
except errors.OutOfRangeError:
logging.warning('Your dataset iterator ran out of data; '
'interrupting training. Make sure that your '
@@ -1688,25 +1825,21 @@ class KerasTPUModel(models.Model):
break
if do_validation:
- val_outs = training_arrays.test_loop(self,
- val_inputs,
- val_targets,
- sample_weights=val_sample_weights,
- steps=validation_steps,
- verbose=0)
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
for l, o in zip(self.metrics_names, val_outs):
epoch_logs['val_' + l] = o
- def _prepare_validation_data(self,
- validation_data,
- validation_split,
- validation_steps,
- x,
- y,
- sample_weights,
+ def _prepare_validation_data(self, validation_data, validation_split,
+ validation_steps, x, y, sample_weights,
batch_size):
"""Prepares the validation dataset.
@@ -1764,8 +1897,10 @@ class KerasTPUModel(models.Model):
x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
- sample_weights, val_sample_weights = (slice_arrays(
- sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ sample_weights, val_sample_weights = (
+ slice_arrays(sample_weights, 0, split_at),
+ slice_arrays(sample_weights, split_at)
+ )
elif validation_steps:
val_x = []
val_y = []
@@ -1777,11 +1912,20 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ @property
+ def optimizer(self):
+ if self._tpu_model:
+ return self._tpu_model.optimizer
+ return self._cpu_model.optimizer
+
+ @optimizer.setter
+ def optimizer(self, optimizer):
+ self._optimizer = optimizer
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self,
- model_fn_lib.ModeKeys.TRAIN,
+ self, model_fn_lib.ModeKeys.TRAIN,
tpu_assignment=self._tpu_assignment)
return self.train_function
@@ -1816,18 +1960,48 @@ class KerasTPUModel(models.Model):
self._tpu_weights_initialized = True
weights = self._cpu_model.get_weights()
- with self.tpu_session():
- logging.info('Setting weights on TPU model.')
- cloned_model.set_weights(weights)
+
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ cpu_optimizer_config = {}
+ else:
+ cpu_optimizer_config = self.cpu_optimizer.get_config()
+
+ logging.info('Setting weights on TPU model.')
+ cloned_model.set_weights(weights)
+ for k, v in six.iteritems(cpu_optimizer_config):
+ opt_var = getattr(self._tpu_model.optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
+
+ @property
+ def cpu_optimizer(self):
+ return self._cpu_model.optimizer
def sync_to_cpu(self):
"""Copy weights from the CPU, returning a synchronized CPU model."""
- if self._tpu_weights_initialized:
- with self.tpu_session():
- logging.info('Copying TPU weights to the CPU')
- tpu_weights = self._tpu_model.get_weights()
+ if not self._tpu_weights_initialized:
+ return self._cpu_model
+
+ logging.info('Copying TPU weights to the CPU')
+ tpu_weights = self._tpu_model.get_weights()
- self._cpu_model.set_weights(tpu_weights)
+ # TFOptimizers have no configurable options
+ if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
+ tpu_optimizer_config = {}
+ else:
+ tpu_optimizer_config = self._tpu_model.optimizer.get_config()
+
+ self._cpu_model.set_weights(tpu_weights)
+ for k, v in six.iteritems(tpu_optimizer_config):
+ logging.info('TPU -> CPU %s: %s', k, v)
+ opt_var = getattr(self.cpu_optimizer, k)
+ if isinstance(opt_var, variables.Variable):
+ K.get_session().run(opt_var.assign(v))
+ else:
+ logging.warning('Cannot update non-variable config: %s', k)
return self._cpu_model
@@ -1848,26 +2022,6 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
- @contextlib.contextmanager
- def tpu_session(self):
- """Yields a TPU session and sets it as the default Keras session."""
- with self._session.graph.as_default():
- default_session = K.get_session()
- # N.B. We have to call `K.set_session()` AND set our session as the
- # TF default. `K.get_session()` surprisingly does not return the value
- # supplied by K.set_session otherwise.
- K.set_session(self._session)
- with self._session.as_default():
- yield self._session
- K.set_session(default_session)
-
- def shutdown(self):
- # TODO(b/111364423): Actually shut down the system.
- logging.info('Skipping shutting down TPU system.')
- # with self.tpu_session() as session:
- # session.run(tpu.shutdown_system())
- self._session.close()
-
# pylint: disable=bad-continuation
def _validate_shapes(model):
@@ -1908,7 +2062,9 @@ Output shape: %(output_shape)s
@experimental
def tpu_model(model, strategy=None):
- """Copy `model` along with weights to the TPU. Returns a TPU model.
+ """Copy `model` along with weights to the TPU.
+
+ Returns a TPU model.
Usage:
```
@@ -1923,21 +2079,16 @@ def tpu_model(model, strategy=None):
model.compile(
optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
...)
- model.shutdown()
```
Args:
- model: A `KerasTPUModel`.
+ model: A `tf.keras.Model` instance.
strategy: `TPUDistributionStrategy`. The strategy to use for replicating
- model across multiple TPU cores.
+ model across multiple TPU cores.
Returns:
A new `KerasTPUModel` instance.
"""
- # Force initialization of the CPU model.
- model.get_weights()
- model.reset_states()
-
_validate_shapes(model)
# TODO(xiejw): Validate TPU model. TPUModel only?
# TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
@@ -1951,4 +2102,34 @@ def tpu_model(model, strategy=None):
'`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
'Got: {}'.format(type(strategy)))
- return KerasTPUModel(cpu_model=model, strategy=strategy)
+ # If the model has already been initialized, grab the optimizer configuration
+ # and model weights before entering the TPU session.
+ if model.optimizer:
+ if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not
+ isinstance(model.optimizer, keras_optimizers.TFOptimizer)):
+ optimizer_config = model.optimizer.get_config()
+ else:
+ optimizer_config = None
+ model_weights = model.get_weights()
+ else:
+ model_weights = None
+
+ setup_tpu_session(strategy._tpu_cluster_resolver)
+
+ # Force initialization of the CPU model in the TPU session.
+ cpu_model = models.clone_model(model)
+ if model.optimizer:
+ cpu_model.compile(
+ _clone_optimizer(model.optimizer, optimizer_config),
+ model.loss,
+ _clone_metrics(model.metrics),
+ model.loss_weights,
+ model.sample_weight_mode,
+ _clone_metrics(model.weighted_metrics),
+ )
+
+ if model_weights:
+ cpu_model.set_weights(model_weights)
+ cpu_model.reset_states()
+
+ return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index 1fb26e701a..ab89c6aa8c 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -112,6 +112,11 @@ class Topology(object):
return self._mesh_shape
@property
+ def mesh_rank(self):
+ """Returns the number of dimensions in the mesh."""
+ return len(self._mesh_shape)
+
+ @property
def device_coordinates(self):
"""Describes the mapping from TPU devices to topology coordinates.
@@ -125,6 +130,16 @@ class Topology(object):
"""
return self._device_coordinates
+ @property
+ def num_tasks(self):
+ """Returns the number of TensorFlow tasks in the TPU slice."""
+ return self._device_coordinates.shape[0]
+
+ @property
+ def num_tpus_per_task(self):
+ """Returns the number of TPU devices per task in the TPU slice."""
+ return self._device_coordinates.shape[1]
+
def serialized(self):
"""Returns the serialized form of the topology."""
if self._serialized is None:
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 0f9f7cd91b..712b02ff0d 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.compat import compat as api_compat
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -75,7 +76,7 @@ def initialize_system(embedding_config=None, job=None):
"""Initializes a distributed TPU system for use with TensorFlow.
Args:
- embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
+ embedding_config: If not None, a `TPUEmbeddingConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
job: The job (the XXX in TensorFlow device specification /job:XXX) that
@@ -558,10 +559,17 @@ def split_compile_and_replicate(computation,
"topology":
device_assignment.topology.serialized(),
"device_assignment":
- device_assignment.core_assignment.flatten().tolist(),
- "computation_shape":
- device_assignment.computation_shape.tolist()
+ device_assignment.core_assignment.flatten().tolist()
}
+ # TODO(phawkins): remove this case after the forward compatibility window
+ # expires on 2018-10-5.
+ if api_compat.forward_compatible(2018, 10, 5):
+ metadata_kwargs["num_cores_per_replica"] = (
+ device_assignment.num_cores_per_replica)
+ else:
+ metadata_kwargs["computation_shape"] = [
+ device_assignment.num_cores_per_replica
+ ]
if ((not isinstance(inputs, list)) or
any(not isinstance(inp, (list, tuple)) for inp in inputs)):
@@ -840,8 +848,12 @@ def shard(computation,
if num_shards <= 0:
raise ValueError("num_shards must be a positive integer.")
+ inputs = [] if inputs is None else inputs
+ if not isinstance(inputs, list):
+ raise TypeError("tpu.shard()'s inputs must be a list of Tensors or None.")
+
# Converts inputs to Tensors.
- inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs]
+ inputs = [ops.convert_to_tensor(x) for x in inputs]
if input_shard_axes is None:
input_shard_axes = [0] * len(inputs)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 18e0abdda2..9f8d147068 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -32,7 +32,6 @@ from tensorflow.python.platform import tf_logging as logging
_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV
_SERVICE_KEY = run_config_lib._SERVICE_KEY
_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'
-_NUM_CORES_PER_HOST = 8
# pylint: enable=protected-access
@@ -103,7 +102,7 @@ class TPUConfig(
input mode.
Raises:
- ValueError: If `num_cores_per_replica` is not 1, 2, 4 or 8.
+ ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8 or 16.
"""
def __new__(cls,
@@ -139,9 +138,9 @@ class TPUConfig(
# Check num_cores_per_replica
if num_cores_per_replica is not None:
- if num_cores_per_replica not in [1, 2, 4, 8]:
+ if num_cores_per_replica not in [1, 2, 4, 8, 16]:
raise ValueError(
- 'num_cores_per_replica must be 1, 2, 4, or 8; got {}'.format(
+ 'num_cores_per_replica must be 1, 2, 4, 8, or 16; got {}'.format(
str(num_cores_per_replica)))
# per_host_input_for_training may be True, False, or integer in [1..3].
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
index 2326fe97a8..b2fe0a6888 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config_test.py
@@ -86,7 +86,7 @@ class TPURunConfigTest(test.TestCase):
def test_fail_with_invalid_num_cores_per_replica(self):
with self.assertRaisesRegexp(
- ValueError, 'num_cores_per_replica must be 1, 2, 4, or 8;'
+ ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, or 16;'
' got 7'):
tpu_config_lib.TPUConfig(num_cores_per_replica=7)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 19359cb612..7cfb6c38fa 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -35,7 +35,8 @@ _NUM_CORES_TO_COMPUTATION_SHAPE = {
1: [1, 1, 1],
2: [1, 1, 2],
4: [1, 2, 2],
- 8: [2, 2, 2]
+ 8: [2, 2, 2],
+ 16: [4, 2, 2],
}
@@ -117,6 +118,11 @@ class TPUContext(object):
return self._internal_ctx.num_hosts
@property
+ def current_host(self):
+ """The current host index for the TPU system."""
+ return self._invocation_index
+
+ @property
def num_of_replicas_per_host(self):
"""The number of replicas for each host."""
if self._internal_ctx.model_parallelism_enabled:
@@ -298,6 +304,7 @@ class _InternalTPUContext(object):
@property
def num_of_replicas_per_host(self):
+ """Return the number of replicas per host."""
if self.model_parallelism_enabled:
return self.num_replicas // self.num_hosts
else:
@@ -538,8 +545,8 @@ class _InternalTPUContext(object):
"""
if self.model_parallelism_enabled:
# We put both enqueue/dequeue ops at tpu.core(0) in each replica.
- replica = self.device_assignment.lookup_replicas(
- host_id, (0, 0, 0))[shard_index_in_host]
+ replica = self.device_assignment.lookup_replicas(host_id,
+ 0)[shard_index_in_host]
return self.device_assignment.tpu_ordinal(replica=replica)
else:
return shard_index_in_host % self.num_of_cores_per_host
@@ -580,6 +587,17 @@ class _InternalTPUContext(object):
raise ValueError(message)
+ if self._config.tpu_config.num_cores_per_replica:
+ num_cores_per_replica = self._config.tpu_config.num_cores_per_replica
+ num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host
+ if num_cores_per_replica > num_cores_per_host:
+ raise ValueError(
+ 'The num of cores required by the model parallelism, specified by '
+ 'TPUConfig.num_cores_per_replica, is larger than the '
+ 'num_cores_per_host. num_cores_per_replica: {}, '
+ 'num_cores_per_host: {}'.format(num_cores_per_replica,
+ num_cores_per_host))
+
if mode == model_fn_lib.ModeKeys.TRAIN:
if (self._train_batch_size % num_replicas != 0 and
not self.is_input_broadcast_with_iterators()):
@@ -599,8 +617,8 @@ class _InternalTPUContext(object):
.format(self._eval_batch_size, num_replicas))
if num_hosts > 1 and not self.is_input_broadcast_with_iterators():
raise ValueError(
- 'TPUEstimator.evaluate should be running on single TPU worker. '
- 'got {}.'.format(num_hosts))
+ 'TPUEstimator.evaluate should be running on single TPU'
+ ' instead of a Pod.')
else:
assert mode == model_fn_lib.ModeKeys.PREDICT
if self._predict_batch_size is None:
@@ -685,7 +703,7 @@ def _get_tpu_context(config, train_batch_size, eval_batch_size,
config.tpu_config.num_cores_per_replica is None):
logging.warning(
'Setting TPUConfig.num_shards==1 is an unsupported behavior. '
- 'Please fix as soon as possible (leaving num_shards as None.')
+ 'Please fix as soon as possible (leaving num_shards as None.)')
return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,
predict_batch_size, use_tpu)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 1ff04f5c26..23c54511ca 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1774,18 +1774,19 @@ class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
summary_writer=summary_writer)
def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
- global_step_per_sec = elapsed_steps / elapsed_time
- examples_per_sec = self._batch_size * global_step_per_sec
+ global_steps_per_sec = elapsed_steps / elapsed_time
+ examples_per_sec = self._batch_size * global_steps_per_sec
if self._summary_writer is not None:
global_step_summary = Summary(value=[
- Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec)
+ Summary.Value(tag='global_steps/sec',
+ simple_value=global_steps_per_sec)
])
example_summary = Summary(value=[
Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
])
self._summary_writer.add_summary(global_step_summary, global_step)
self._summary_writer.add_summary(example_summary, global_step)
- logging.info('global_step/sec: %g', global_step_per_sec)
+ logging.info('global_steps/sec: %g', global_steps_per_sec)
logging.info('examples/sec: %g', examples_per_sec)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index d9c77a3ea1..e75a09492e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -765,9 +765,8 @@ class _PartitionedInfeedQueue(InfeedQueue):
zip(per_host_sharded_inputs[replica_index], inputs_part_dims_flat)
]
- for core_index in xrange(self._device_assignment.num_cores_per_replica):
+ for logical_core in xrange(self._device_assignment.num_cores_per_replica):
# Places different partitions to different logic cores.
- logical_core = self._get_logical_core(core_index)
replica_id = self._device_assignment.lookup_replicas(
self._host_id, logical_core)[replica_index]
ordinal = self._device_assignment.tpu_ordinal(
@@ -784,7 +783,7 @@ class _PartitionedInfeedQueue(InfeedQueue):
inputs=infeed_inputs,
shapes=[x.shape for x in infeed_inputs],
name="enqueue/replica_{0}/input_{1}".format(
- replica_index, core_index),
+ replica_index, logical_core),
device_ordinal=ordinal))
return per_host_enqueue_ops
@@ -890,20 +889,3 @@ class _PartitionedInfeedQueue(InfeedQueue):
return nest.map_structure_up_to(
dequeues, self._tag_sharding_attribute_for_dequeued_tensor, dequeues,
dims)
-
- def _get_logical_core(self, core_index):
- """Maps the core index to the 3D coordinate within replica.
-
- The lowest dimension number in computation_shape is the slowest varying
- dimension (most major).
-
- Args:
- core_index: An integer represents the core index within replcia.
-
- Returns:
- A tuple with three integers which represents the 3D coordinate.
- """
- computation_shape = self._device_assignment.computation_shape
- return (core_index // (computation_shape[1] * computation_shape[2]),
- core_index % (computation_shape[1] * computation_shape[2]) //
- computation_shape[2], core_index % computation_shape[2])
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
index de16e3b157..0c7a38dbbb 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_function.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py
@@ -63,10 +63,9 @@ def check_function_argument_count(func, input_arity, infeed_queue):
"""Validate the number of input arguments to a tpu function.
Args:
- func: the Python function that will be called to generate the body
- of a TPUFunction.
- input_arity: the number of explicit arguments supplied by the
- caller.
+ func: the Python function that will be called to generate the body of an XLA
+ computation graph.
+ input_arity: the number of explicit arguments supplied by the caller.
infeed_queue: if not None, the infeed queue that will supply
additional arguments to the function.
@@ -103,4 +102,3 @@ def check_function_argument_count(func, input_arity, infeed_queue):
# Since there are varargs, func can accept any number of arguments
# greater than the minimum.
return None
-
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
index f46d03209c..8896a95327 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -29,7 +29,7 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import nest as tf_nest
-class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
+class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that prepends a queue to another `Dataset`.
A vector of handles to the queue is returned as the first component of
@@ -39,7 +39,7 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
"""Initialize `PrependFromQueueAndPaddedBatchDataset`."""
- super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
+ super(_PrependFromQueueAndPaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
raise TypeError(
"Batching of padded sparse tensors is not currently supported")
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
index 3cb5e61fac..2784bf124c 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
-#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/pool_allocator.h"
@@ -29,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -256,74 +256,41 @@ void MRDeleter(ibv_mr* mr) {
}
}
-// TODO(byronyi): remove this class and its registration when the default
-// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
-// longer in use.
-class BFCRdmaAllocator : public BFCAllocator {
- public:
- BFCRdmaAllocator()
- : BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_rdma_bfc") {}
-};
-class BFCRdmaAllocatorFactory : public AllocatorFactory {
- public:
- Allocator* CreateAllocator() { return new BFCRdmaAllocator; }
-
- SubAllocator* CreateSubAllocator(int numa_node) {
- return new BasicCPUAllocator(numa_node);
- }
-};
-
-REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory);
-
void RdmaMgr::InitAllocators() {
- RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;
+ static std::once_flag flag;
+ std::call_once(
+ flag, [this]() { RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_; });
+}
- Allocator* allocators[] = {
-#if GOOGLE_CUDA
- GPUProcessState::singleton()->GetCUDAHostAllocator(0),
-#endif // GOOGLE_CUDA
- ProcessState::singleton()->GetCPUAllocator(0),
- cpu_allocator(),
+/*static*/ void RdmaMgr::RegMemVisitors() {
+ SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().EvictMemoryRegion(ptr, num_bytes);
};
- using namespace std::placeholders;
-
- std::set<Allocator*> instrumented_;
-
- // Host memory allocators
- for (Allocator* allocator : allocators) {
- VisitableAllocator::Visitor alloc_visitor =
- std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
- &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
- VisitableAllocator::Visitor free_visitor = std::bind(
- &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);
-
- auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
- CHECK(visitable_allocator)
- << "is not visitable for instrumentation" << allocator->Name();
- // Make sure we don't instrument the same allocator twice
- if (instrumented_.find(allocator) == std::end(instrumented_)) {
- visitable_allocator->AddAllocVisitor(alloc_visitor);
- visitable_allocator->AddFreeVisitor(free_visitor);
- instrumented_.insert(allocator);
- LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
- }
- }
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
#if GOOGLE_CUDA
if (IsGDRAvailable()) {
// Note we don't free allocated GPU memory so there is no free visitor
int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
- char buf[8];
- sprintf(buf, "gpu");
- VisitableAllocator::Visitor cuda_alloc_visitor =
- std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
- &RdmaMemoryMgr::Singleton(), _1, _2, std::string(buf));
-
+ SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ RdmaMemoryMgr::Singleton().InsertMemoryRegion(
+ ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
}
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
index 9fffc335bb..74b92cc9a6 100644
--- a/tensorflow/contrib/verbs/rdma_mgr.h
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -39,6 +39,7 @@ class RdmaMgr {
void SetupChannels();
bool ConnectivityCheck();
void InitAllocators();
+ static void RegMemVisitors();
const string& local_worker() { return local_worker_; }
private:
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 1a0b5028fe..5b72b1604a 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -76,8 +76,13 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
return Status::OK();
}
+namespace {
+std::once_flag reg_mem_visitors_call;
+} // namespace
+
Status VerbsServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func) {
+ std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); });
Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
{
mutex_lock l(mu_);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 79ad3b8e54..bc0bfb793c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -85,11 +85,12 @@ load(
"tf_cc_tests",
"tf_copts",
"tf_cuda_library",
+ "tf_features_nomodules_if_android",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
- "tf_features_nomodules_if_android",
+ "transitive_hdrs",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -120,16 +121,16 @@ load(
"tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
- "tf_additional_proto_hdrs",
"tf_additional_proto_compiler_hdrs",
+ "tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
"tf_additional_verbs_lib_defines",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
- "tf_lib_proto_parsing_deps",
"tf_lib_proto_compiler_deps",
+ "tf_lib_proto_parsing_deps",
"tf_nano_proto_library",
"tf_platform_hdrs",
"tf_platform_srcs",
@@ -143,6 +144,7 @@ load(
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
+ "if_dynamic_kernels",
"if_static",
"tf_cuda_tests_tags",
)
@@ -168,6 +170,7 @@ COMMON_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
"framework/allocation_description.proto",
+ "framework/api_def.proto",
"framework/attr_value.proto",
"framework/cost_graph.proto",
"framework/device_attributes.proto",
@@ -179,7 +182,6 @@ COMMON_PROTO_SRCS = [
"framework/log_memory.proto",
"framework/node_def.proto",
"framework/op_def.proto",
- "framework/api_def.proto",
"framework/reader_base.proto",
"framework/remote_fused_graph_execute_info.proto",
"framework/resource_handle.proto",
@@ -299,6 +301,7 @@ filegroup(
name = "platform_base_hdrs",
srcs = [
"platform/byte_order.h",
+ "platform/cord.h",
"platform/env_time.h",
"platform/logging.h",
"platform/macros.h",
@@ -720,6 +723,7 @@ cc_library(
name = "abi",
srcs = ["platform/abi.cc"],
hdrs = ["platform/abi.h"],
+ deps = [":platform_base"],
)
cc_library(
@@ -874,7 +878,6 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
- "util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@@ -1065,7 +1068,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1073,6 +1075,13 @@ tf_gen_op_libs(
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
@@ -1284,8 +1293,8 @@ cc_library(
# This includes implementations of all kernels built into TensorFlow.
cc_library(
- name = "all_kernels",
- visibility = ["//visibility:public"],
+ name = "all_kernels_statically_linked",
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
@@ -1328,6 +1337,7 @@ cc_library(
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
+ "//tensorflow/core/kernels:searchsorted_op",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
@@ -1363,6 +1373,15 @@ cc_library(
]),
)
+cc_library(
+ name = "all_kernels",
+ visibility = ["//visibility:public"],
+ deps = if_dynamic_kernels(
+ [],
+ otherwise = [":all_kernels_statically_linked"],
+ ),
+)
+
tf_cuda_library(
name = "tensorflow_opensource",
copts = tf_copts(),
@@ -1426,9 +1445,11 @@ cc_library(
":test",
":testlib_ops",
"//tensorflow/cc:scope",
+ "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:random_ops",
],
)
@@ -1918,6 +1939,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "protobuf/config_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/config.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "protobuf/device_properties_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "protobuf/device_properties.proto",
@@ -2056,6 +2084,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"platform/snappy.h",
"platform/tensor_coding.h",
"platform/tracing.h",
+ "util/env_var.h",
]
# Replicated for lib_internal and lib_internal_impl.
@@ -2083,6 +2112,7 @@ cc_library(
deps = tf_additional_lib_deps() + [
"@com_google_absl//absl/strings",
"//third_party/eigen3",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
)
@@ -2095,6 +2125,7 @@ cc_library(
"platform/*.cc",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.cc",
+ "util/env_var.cc",
],
exclude = [
"**/*test*",
@@ -2274,6 +2305,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2306,6 +2338,7 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)
@@ -2450,7 +2483,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/unique_tensor_references.h",
"framework/variant.h",
"util/command_line_flags.h",
- "util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
@@ -2479,7 +2511,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
deps = [
":lib",
":lib_internal",
@@ -2526,6 +2563,7 @@ tf_cuda_library(
"util/memmapped_file_system_writer.*",
"util/stats_calculator.*",
"util/version_info.cc",
+ "util/env_var.cc",
],
) + select({
"//tensorflow:windows": [],
@@ -2564,11 +2602,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
extra_deps = [
- # ABSL headers get dropped, so we add them back here.
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
],
- includes = ["../../external/com_google_absl"],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2578,7 +2617,12 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
- includes = ["../../external/com_google_absl"],
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -2769,8 +2813,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
- "common_runtime/tracing_device.h",
- "common_runtime/visitable_allocator.h",
"common_runtime/process_state.h",
"common_runtime/pool_allocator.h",
"graph/gradients.h",
@@ -2966,12 +3008,16 @@ tf_cuda_library(
] + tf_additional_device_tracer_deps(),
)
-cc_library(
- name = "session_ref",
- srcs = ["common_runtime/session_ref.cc"],
- hdrs = ["common_runtime/session_ref.h"],
- copts = tf_copts(),
- deps = [":core_cpu_base"],
+tf_proto_library_cc(
+ name = "replay_log_proto",
+ srcs = ["protobuf/replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",
+ ] + tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
)
cc_library(
@@ -4708,6 +4754,18 @@ cc_library(
] + tf_additional_libdevice_deps(),
)
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000000..cdaeb5091c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+ graph_op_name: "BoostedTreesBucketize"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensor each containing float values for a single feature.
+END
+ }
+ in_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+ }
+ out_arg {
+ name: "buckets"
+ description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features.
+END
+ }
+ summary: "Bucketize each feature based on bucket boundaries."
+ description: <<END
+An op that returns a list of float tensors, where each tensor represents the
+bucketized values for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000000..20da1295f6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; Handle to quantile stream resource.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required approximation error of the stream resource.
+END
+ }
+ in_arg {
+ name: "num_streams"
+ description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+ }
+ attr {
+ name: "max_elements"
+ description : <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+ }
+ summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000000..ca111af312
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "BoostedTreesMakeQuantileSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+ }
+ in_arg {
+ name: "example_weights"
+ description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required maximum approximation error.
+END
+ }
+ out_arg {
+ name: "summaries"
+ description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+ }
+ summary: "Makes the summary of quantiles for the batch."
+ description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000000..bbeecbf32b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "summaries"
+ description: <<END
+string; List of Rank 2 Tensor each containing the summaries for a single feature.
+END
+ }
+ summary: "Add the quantile summaries to each quantile stream resource."
+ description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000000..2fd94efa10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "num_buckets",
+ description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+ }
+ attr {
+ name: "generate_quantiles"
+ description: <<END
+bool; If True, the output will be the num_quantiles for each stream where the ith
+entry is the ith quantile of the input with an approximation error of epsilon.
+Duplicate values may be present.
+If False, the output will be the points in the histogram that we got which roughly
+translates to 1/epsilon boundaries and without any duplicates.
+Default to False.
+END
+ }
+ summary: "Flush the summaries for a quantile stream resource."
+ description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000000..206672802f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ out_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+ }
+ summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+ description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..cb7786c051
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
index e39213cbc7..440800704e 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
@@ -11,7 +11,8 @@ END
name: "record_defaults"
description: <<END
One tensor per column of the input record, with either a
-scalar default value for that column or empty if the column is required.
+scalar default value for that column or an empty vector if the column is
+required.
END
}
out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
new file mode 100644
index 0000000000..3c8a455983
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "ExtractVolumePatches"
+ in_arg {
+ name: "input"
+ description: <<END
+5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+END
+ }
+ out_arg {
+ name: "patches"
+ description: <<END
+5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+are the dimensions of the output patches.
+END
+ }
+ attr {
+ name: "ksizes"
+ description: <<END
+The size of the sliding window for each dimension of `input`.
+END
+ }
+ attr {
+ name: "strides"
+ description: <<END
+1-D of length 5. How far the centers of two consecutive patches are in
+`input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+END
+ }
+ attr {
+ name: "padding"
+ description: <<END
+The type of padding algorithm to use.
+
+We specify the size-related attributes as:
+
+```python
+ ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+ strides = [1, stride_planes, strides_rows, strides_cols, 1]
+```
+END
+ }
+ summary: <<END
+Extract `patches` from `input` and put them in the "depth" output
+dimension. 3D extension of `extract_image_patches`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000000..758eeb96f0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; The reference to quantile stream resource handle.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+ }
+ summary: "Checks whether a quantile stream has been initialized."
+ description: <<END
+An Op that checks if quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
new file mode 100644
index 0000000000..5ce825ae04
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "LowerBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
+the values that will be searched for in `sorted_search_values`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the first scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies lower_bound(sorted_search_values, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='left')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = LowerBound(sorted_sequence, values)
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
new file mode 100644
index 0000000000..171add16d4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "ModelDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "Identity transformation that models performance."
+ description: <<END
+Identity transformation that models performance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
new file mode 100644
index 0000000000..4b0a5d8f65
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "MultiDeviceIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "devices"
+ description: <<END
+A list of devices the iterator works across.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name
+across multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Creates a MultiDeviceIterator resource."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000000..adaacd8ab7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "MultiDeviceIteratorFromStringHandle"
+ in_arg {
+ name: "string_handle"
+ description: <<END
+String representing the resource.
+END
+ }
+ out_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Generates a MultiDeviceIterator resource from its provided string handle."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
new file mode 100644
index 0000000000..f9be9188cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "MultiDeviceIteratorGetNextFromShard"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ in_arg {
+ name: "shard_num"
+ description: <<END
+Integer representing which shard to fetch data for.
+END
+ }
+ in_arg {
+ name: "incarnation_id"
+ description: <<END
+Which incarnation of the MultiDeviceIterator is running.
+END
+ }
+ out_arg {
+ name: "components"
+ description: <<END
+Result of the get_next on the dataset.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Gets next element for the provided shard number."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
new file mode 100644
index 0000000000..6b54fa1307
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "MultiDeviceIteratorInit"
+ in_arg {
+ name: "dataset"
+ description: <<END
+Dataset to be iterated upon.
+END
+ }
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIteratorResource.
+END
+ }
+ in_arg {
+ name: "max_buffer_size"
+ description: <<END
+The maximum size of the host side per device buffer to keep.
+END
+ }
+ out_arg {
+ name: "incarnation_id"
+ description: <<END
+An int64 indicating which incarnation of the MultiDeviceIterator
+is running.
+END
+ }
+ summary: "Initializes the multi device iterator with the given dataset."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
new file mode 100644
index 0000000000..1f1fdf99b4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "MultiDeviceIteratorToStringHandle"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ out_arg {
+ name: "string_handle"
+ description: <<END
+A string representing the resource.
+END
+ }
+ summary: "Produces a string handle for the given MultiDeviceIterator."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..4cb8955dcb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..a82dae9e48
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+= The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string, at each placeholder in the template a subsequent tensor summary will be inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries print the first and last summarize entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
index cc21ddc815..7d2fbcd00b 100644
--- a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
@@ -1,5 +1,15 @@
op {
graph_op_name: "StringLength"
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is counted to compute string length. One of: `"BYTE"` (for
+the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8
+encoded Unicode code points in each string). Results are undefined
+if `unit=UTF8_CHAR` and the `input` strings do not contain structurally
+valid UTF-8.
+END
+ }
in_arg {
name: "input"
description: <<END
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 8fc1e5cba3..5246090ab3 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -32,8 +32,10 @@ For each string in the input `Tensor`, creates a substring starting at index
If `len` defines a substring that would extend beyond the length of the input
string, then as many characters as possible are used.
-If `pos` is negative or specifies a character index larger than any of the input
-strings, then an `InvalidArgumentError` is thrown.
+A negative `pos` indicates distance within the string backwards from the end.
+
+If `pos` specifies an index which is out of range for any of the input strings,
+then an `InvalidArgumentError` is thrown.
`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
Op creation.
diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
new file mode 100644
index 0000000000..0630f6e82a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "UpperBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
+the values that will be searched for in `sorted_search_values`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the last scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies upper_bound(sorted_search_values, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='right')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = UpperBound(sorted_sequence, values)
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660479..01387b7527 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@ op {
visibility: HIDDEN
graph_op_name: "WindowDataset"
in_arg {
- name: "window_size"
+ name: "size"
description: <<END
A scalar representing the number of elements to accumulate in a window.
END
}
+ in_arg {
+ name: "shift"
+ description: <<END
+A scalar representing the steps moving the sliding window forward in one
+iteration. It must be positive.
+END
+ }
+ in_arg {
+ name: "stride"
+ description: <<END
+A scalar representing the stride of the input elements of the sliding window.
+It must be positive.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether a window should be dropped in case its size is
+smaller than desired.
+END
+ }
summary: "A dataset that creates window datasets from the input dataset."
}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..e22d980424
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..8f0b1db45d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
index 01c02e1f70..df012414e3 100644
--- a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "StringLength"
- endpoint {
- name: "strings.length"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 84c6285bbe..3843ea9e60 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -31,7 +31,7 @@ namespace tensorflow {
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
bool allow_growth, const string& name)
- : suballocator_(sub_allocator),
+ : sub_allocator_(sub_allocator),
name_(name),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1) {
@@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() {
VLOG(2) << "Number of regions allocated: "
<< region_manager_.regions().size();
for (const auto& region : region_manager_.regions()) {
- suballocator_->Free(region.ptr(), region.memory_size());
+ sub_allocator_->Free(region.ptr(), region.memory_size());
}
for (BinNum b = 0; b < kNumBins; b++) {
@@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Try allocating.
size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
- void* mem_addr = suballocator_->Alloc(alignment, bytes);
+ void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
if (mem_addr == nullptr && !started_backpedal_) {
// Only backpedal once.
started_backpedal_ = true;
@@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
while (mem_addr == nullptr) {
bytes = RoundedBytes(bytes * kBackpedalFactor);
if (bytes < rounded_bytes) break;
- mem_addr = suballocator_->Alloc(alignment, bytes);
+ mem_addr = sub_allocator_->Alloc(alignment, bytes);
}
}
@@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Insert the chunk into the right bin.
InsertFreeChunkIntoBin(h);
- // Invoke visitors on newly allocated region.
- for (const auto& visitor : region_visitors_) {
- visitor(mem_addr, bytes);
- }
return true;
}
@@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
InsertFreeChunkIntoBin(coalesced_chunk);
}
-void BFCAllocator::AddAllocVisitor(Visitor visitor) {
- VLOG(1) << "AddVisitor";
- mutex_lock l(lock_);
- region_visitors_.push_back(visitor);
- for (const auto& region : region_manager_.regions()) {
- visitor(region.ptr(), region.memory_size());
- }
-}
-
bool BFCAllocator::TracksAllocationSizes() { return true; }
size_t BFCAllocator::RequestedSize(const void* ptr) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 20e1dab1d5..2d74bf2b28 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
@@ -42,7 +42,7 @@ namespace tensorflow {
// coalescing. One assumption we make is that the process using this
// allocator owns pretty much all of the memory, and that nearly
// all requests to allocate memory go through this interface.
-class BFCAllocator : public VisitableAllocator {
+class BFCAllocator : public Allocator {
public:
// Takes ownership of sub_allocator.
BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
@@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator {
const AllocationAttributes& allocation_attr) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
-
- // Does nothing, because memory is never freed.
- void AddFreeVisitor(Visitor visitor) override {}
-
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
@@ -309,7 +304,7 @@ class BFCAllocator : public VisitableAllocator {
};
// Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
+ static size_t RoundedBytes(size_t bytes);
// Try to add a new memory region that can satisfy an allocation of
// 'rounded_bytes' bytes. Returns true on success and false on
@@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator {
// of the available memory.
bool started_backpedal_ = false;
- std::unique_ptr<SubAllocator> suballocator_;
+ std::unique_ptr<SubAllocator> sub_allocator_;
string name_;
// Structures mutable after construction
@@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator {
// Pointer to head of linked list of free Chunks
ChunkHandle free_chunks_list_ GUARDED_BY(lock_);
- // Called once on each region, ASAP.
- std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
-
// Counter containing the next unique identifier to assign to a
// newly-created chunk.
int64 next_allocation_id_ GUARDED_BY(lock_);
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 97b6971c5b..99cb9ac6a0 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -61,6 +61,7 @@ bool ReadPartialShapesFromShapeMap(
shape_map,
std::vector<PartialTensorShape>* input_shapes) {
CHECK(shape_map != nullptr);
+ input_shapes->resize(n->num_inputs());
for (const Edge* in : n->in_edges()) {
// Don't need to check if incoming control edges have known shapes.
if (in->IsControlEdge()) continue;
@@ -71,7 +72,9 @@ bool ReadPartialShapesFromShapeMap(
}
const auto& known_shape = known_shape_iter->second;
CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
- input_shapes->push_back(known_shape[in->src_output()]);
+ DCHECK_GE(in->dst_input(), 0);
+ DCHECK_LT(in->dst_input(), input_shapes->size());
+ (*input_shapes)[in->dst_input()] = known_shape[in->src_output()];
}
return true;
}
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index f8cb854b52..d800a86199 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -347,7 +347,12 @@ namespace {
static Status WrappedTensorDeviceCopy(
const Tensor& from, Tensor* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
- if (DMAHelper::CanUseDMA(&from)) {
+ if (from.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ } else if (DMAHelper::CanUseDMA(&from)) {
TF_RETURN_IF_ERROR(copy(from, to));
} else {
*to = from;
@@ -358,7 +363,7 @@ static Status WrappedTensorDeviceCopy(
#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+ Tensor, DIRECTION, WrappedTensorDeviceCopy)
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3be4..2ef1547cd9 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -101,11 +101,21 @@ class Device : public DeviceBase {
}
}
+ // If true, and tracing is enabled, the `tracing::ScopedAnnotation()` tracing
+ // mechanism will be used instead of `tracing::ScopedActivity()`. Some devices
+ // may override this method to use annotations, which enable child activities
+ // (such as GPU kernel launches) to be related to the OpKernel invocation.
+ virtual bool TraceUsingAnnotations() const { return false; }
+
// Blocks until all operations queued on the device at the time of
// the call have completed. Returns any error pending on the device
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index eb388202fa..af5d5b17e7 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -1202,14 +1202,11 @@ Status DirectSession::CreateExecutors(
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -1222,13 +1219,11 @@ Status DirectSession::CreateExecutors(
create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
- }
};
- optimizer.Optimize(lib, options_.env, device, &iter->second,
+ optimizer.Optimize(lib, options_.env, device, &partition_graph,
/*shape_map=*/nullptr);
// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 3f2355e530..65e816c202 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -1255,7 +1255,7 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -1308,7 +1308,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 0b096a14a3..2ed4f69f90 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_min_graph_nodes(-1);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_pin_to_host_optimization(RewriterConfig::OFF);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 37fc031985..18420b60fd 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -32,6 +32,18 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
return default_val;
}
+std::unique_ptr<thread::ThreadPool> EagerThreadPool(
+ const SessionOptions& opts) {
+ SessionOptions opts_copy(opts);
+ if (opts_copy.config.inter_op_parallelism_threads() == 0) {
+ // Eager defaults to a single thread when no threads are specified.
+ opts_copy.config.set_inter_op_parallelism_threads(1);
+ }
+
+ return std::unique_ptr<thread::ThreadPool>(
+ NewThreadPoolFromSessionOptions(opts_copy));
+}
+
} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
@@ -49,7 +61,7 @@ EagerContext::EagerContext(const SessionOptions& opts,
: policy_(default_policy),
devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
+ thread_pool_(EagerThreadPool(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
thread_pool_.get())),
@@ -66,13 +78,9 @@ EagerContext::EagerContext(const SessionOptions& opts,
local_unowned_device_manager_ = device_mgr;
}
InitDeviceMapAndAsync();
- if (opts.config.inter_op_parallelism_threads() > 0) {
- runner_ = [this](std::function<void()> closure) {
- this->thread_pool_->Schedule(closure);
- };
- } else {
- runner_ = [](std::function<void()> closure) { closure(); };
- }
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(std::move(closure));
+ };
}
void EagerContext::InitDeviceMapAndAsync() {
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1da1326a9a..1bc63616d0 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -251,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
- for (int i = 0; i < op->Inputs().size(); ++i) {
- Device* input_op_device = nullptr;
- status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
- VLOG(2) << "for op " << op->Name() << " input " << i << " "
- << DataTypeString(op->Inputs()[i]->dtype) << " "
- << (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
- if (op->Inputs()[i]->dtype == DT_RESOURCE &&
- (input_op_device != op->Device() || input_op_device == nullptr)) {
- Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->Name() << " to "
- << d->name() << " because input #" << i
- << " is a resource in this device.";
- op->SetDevice(d);
- }
- }
Device* device = op->Device();
Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -604,6 +584,27 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
Status EagerExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
+ // Ensure all resource-touching ops run in the device the resource is,
+ // regardless of anything else that has been specified. This is identical to
+ // the graph mode behavior.
+ EagerContext* ctx = op->EagerContext();
+ for (int i = 0; i < op->Inputs().size(); ++i) {
+ Device* input_op_device = nullptr;
+ auto status = op->Inputs()[i]->OpDevice(&input_op_device);
+ if (!status.ok()) return status;
+ VLOG(2) << "for op " << op->Name() << " input " << i << " "
+ << DataTypeString(op->Inputs()[i]->dtype) << " "
+ << (input_op_device == nullptr ? "cpu" : input_op_device->name())
+ << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+ if (op->Inputs()[i]->dtype == DT_RESOURCE &&
+ (input_op_device != op->Device() || input_op_device == nullptr)) {
+ Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+ VLOG(1) << "Changing device of operation " << op->Name() << " to "
+ << d->name() << " because input #" << i
+ << " is a resource in this device.";
+ op->SetDevice(d);
+ }
+ }
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 59f94506b7..83d8425477 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -32,21 +32,6 @@ limitations under the License.
namespace tensorflow {
// static
-Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out) {
- OpKernel* k = nullptr;
- Status s = CreateOpKernel(device->device_type().c_str(), device,
- device->GetAllocator(AllocatorAttributes()),
- nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
- out->device_ = device;
- out->kernel_.reset(k);
- out->flib_ = nullptr;
- out->runner_ = nullptr;
- out->default_runner_ = [](std::function<void()> f) { f(); };
- return s;
-}
-
-// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index ed76c4f601..04151a1171 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -52,9 +52,6 @@ class KernelAndDevice {
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
- // TODO(ashankar): Remove this
- static Status InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out);
KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
: device_(nullptr),
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index b912f7d37b..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c6531a..e55f1a0338 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted {
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
+ Status NumElements(int64* num_elements);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 84865397bc..2c48084cab 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) {
namespace nodestats {
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
if (!stats) return;
stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorStarted();
}
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeStarted();
}
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordComputeEnded();
}
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
stats->RecordExecutorEnded();
}
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
if (!stats) return;
stats->SetOutput(slot, v);
}
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
if (!stats) return;
stats->SetMemory(ctx);
}
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
stats->SetReferencedTensors(tensors);
}
-// Sets the timeline_label field of *stats, using data from *node.
-// Returns true iff the node is a transfer node.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- if (!stats) {
- return false;
- }
- return stats->SetTimelineLabel(node);
-}
-
} // namespace nodestats
class ExecutorImpl;
@@ -152,6 +143,8 @@ struct NodeItem {
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
+ bool is_constant_enter : 1; // True iff IsEnter(node) and
+ // node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_sink : 1; // True iff IsSink(node)
@@ -635,6 +628,14 @@ Status ExecutorImpl::Initialize() {
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
item->is_merge = IsMerge(n);
item->is_enter = IsEnter(n);
+ if (item->is_enter) {
+ bool is_constant_enter;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
+ item->is_constant_enter = is_constant_enter;
+ } else {
+ item->is_constant_enter = false;
+ }
item->is_exit = IsExit(n);
item->is_control_trigger = IsControlTrigger(n);
item->is_sink = IsSink(n);
@@ -1237,6 +1238,9 @@ class ExecutorState {
// Step-local container.
ScopedStepContainer* step_container_;
StepStatsCollectorInterface* const stats_collector_;
+ const tracing::TraceCollector* const trace_collector_;
+ const tracing::EventCollector* const event_collector_;
+
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
@@ -1245,6 +1249,7 @@ class ExecutorState {
CancellationManager* cancellation_manager_;
Executor::Args::Runner runner_;
bool sync_on_finish_;
+ const bool trace_using_annotations_;
// Owned.
@@ -1301,7 +1306,7 @@ class ExecutorState {
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsWrapper* stats);
+ EntryVector* outputs, NodeExecStatsInterface* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1312,7 +1317,7 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1359,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
+ trace_collector_(tracing::GetTraceCollector()),
+ event_collector_(
+ tracing::GetEventCollector(tracing::EventCategory::kCompute)),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
cancellation_manager_(args.cancellation_manager),
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
+ trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
@@ -1513,7 +1522,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
- NodeExecStatsWrapper* _stats)
+ NodeExecStatsInterface* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1538,7 +1547,7 @@ struct ExecutorState::AsyncState {
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStatsWrapper* stats;
+ NodeExecStatsInterface* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1550,6 +1559,32 @@ struct ExecutorState::AsyncState {
}
};
+// Returns true if `item` might be traced by the given trace and event
+// collectors. Returns false only if `item` definitely will not be traced.
+bool MightTrace(const NodeItem& item,
+ const tracing::TraceCollector* trace_collector,
+ const tracing::EventCollector* event_collector,
+ bool using_annotations) {
+ // Tracing will only be enabled if either `event_collector` is non null,
+ // or `trace_collector` is non-null and enabled for this particular kernel.
+ // Although `tracing::ScopedActivity`,
+ // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
+ // these properties internally in their constructors, the cost of passing the
+ // necessary arguments to them can be significant, so we avoid constructing
+ // them in the common case (when we know they will not be used).
+ if (event_collector != nullptr) {
+ return true;
+ }
+ if (trace_collector) {
+ if (using_annotations) {
+ return trace_collector->IsEnabledForAnnotations();
+ } else {
+ return trace_collector->IsEnabledForActivities(item.kernel_is_expensive);
+ }
+ }
+ return false;
+}
+
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
@@ -1583,7 +1618,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.stats_collector = stats_collector_;
Status s;
- NodeExecStatsWrapper* stats = nullptr;
+ NodeExecStatsInterface* stats = nullptr;
+
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1613,7 +1649,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper(node->name());
+ stats = stats_collector_->CreateNodeExecStats(node);
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
@@ -1671,7 +1707,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStatsWrapper* stats = state->stats; // Shorthand
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
nodestats::SetOpEnd(stats);
@@ -1720,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
- device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+
+ if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_,
+ event_collector_,
+ trace_using_annotations_))) {
+ const string& op_name = op_kernel->name();
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+ op_name);
+ if (trace_using_annotations_) {
+ // The OpKernel may create child activities (such as GPU kernel
+ // launches), so use a `ScopedAnnotation` to relate these activities
+ // in the trace.
+ tracing::ScopedAnnotation activity(op_name,
+ op_kernel->type_string());
+ device->Compute(op_kernel, &ctx);
+ } else {
+ // Use the cheaper `ScopedActivity` to trace just the OpKernel
+ // execution.
+ tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
+ item.kernel_is_expensive);
+ device->Compute(op_kernel, &ctx);
+ }
+ } else {
+ // In the common case, avoid creating any tracing objects.
+ device->Compute(op_kernel, &ctx);
+ }
+
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {
@@ -1862,7 +1923,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStatsWrapper* stats) {
+ NodeExecStatsInterface* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -1997,15 +2058,12 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
- bool is_constant;
- const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
- DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
{
const NodeItem* item = impl_->gview_.node(node->id());
mutex_lock l(output_frame->mu);
- if (is_constant) {
+ if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
@@ -2080,16 +2138,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr &&
- !nodestats::SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- // Transfers 'stats' ownership to 'stats_collector_'.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else if (stats) {
- delete stats;
+ if (stats) {
+ if (stats_collector_) {
+ stats->Done(impl_->params_.device->name());
+ } else {
+ delete stats;
+ }
}
bool abort_run = false;
@@ -2311,13 +2368,15 @@ void ExecutorState::Finish() {
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 1c9b69721d..472865ca43 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
&fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
- *kernel = new CallOp(handle, &construction);
- if (!s.ok()) {
- delete *kernel;
+ if (s.ok()) {
+ *kernel = new CallOp(handle, &construction);
}
return s;
}
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
index 636cd43575..6bd29ef775 100644
--- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -26,8 +26,12 @@ namespace tensorflow {
class CUDAHostAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ numa_node_(numa_node) {
CHECK(stream_exec_ != nullptr);
}
~CUDAHostAllocator() override {}
@@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator {
if (ptr == nullptr) {
LOG(WARNING) << "could not allocate pinned host memory of size: "
<< num_bytes;
+ return ptr;
}
+ VisitAlloc(ptr, numa_node_, num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, numa_node_, num_bytes);
stream_exec_->HostMemoryDeallocate(ptr);
}
}
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const int numa_node_;
TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 2d4c8d0201..42021e51f3 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,18 +22,48 @@ limitations under the License.
namespace tensorflow {
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name)
- : GPUBFCAllocator(cuda_gpu_id, total_memory, GPUOptions(), name) {}
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+ const char* force_allow_growth_string =
+ std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ if (force_allow_growth_string == nullptr) {
+ return gpu_options.allow_growth();
+ }
+
+ if (strcmp("false", force_allow_growth_string) == 0) {
+ if (gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return false;
+ } else if (strcmp("true", force_allow_growth_string) == 0) {
+ if (!gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return true;
+ }
+
+ LOG(ERROR)
+ << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+ << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+ << " values are \"true\" or \"false\". Using original config value"
+ << " of " << gpu_options.allow_growth() << ".";
+ return gpu_options.allow_growth();
+}
+
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory, const string& name)
+ : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(
- new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
- gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
- gpu_options.experimental().use_unified_memory()),
- total_memory, gpu_options.allow_growth(), name) {}
+ : BFCAllocator(sub_allocator, total_memory,
+ GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index f1cc2eace1..d4c9cee89a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -31,28 +31,20 @@ limitations under the License.
namespace tensorflow {
-// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm.
-class GPUBFCAllocator : public BFCAllocator {
- public:
- // 'cuda_gpu_id' refers to the ID of the GPU device within
- // the process and must reference a valid ID in the process.
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name);
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const GPUOptions& gpu_options, const string& name);
- virtual ~GPUBFCAllocator() {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
-};
-
// Suballocator for GPU memory.
class GPUMemAllocator : public SubAllocator {
public:
+ // 'platform_gpu_id' refers to the ID of the GPU device within
+ // the process and must reference a valid ID in the process.
// Note: stream_exec cannot be null.
explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
- bool use_unified_memory)
- : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
+ PlatformGpuId gpu_id, bool use_unified_memory,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ gpu_id_(gpu_id),
+ use_unified_memory_(use_unified_memory) {
CHECK(stream_exec_ != nullptr);
}
~GPUMemAllocator() override {}
@@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator {
} else {
ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
}
+ VisitAlloc(ptr, gpu_id_.value(), num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, gpu_id_.value(), num_bytes);
if (use_unified_memory_) {
stream_exec_->UnifiedMemoryDeallocate(ptr);
} else {
@@ -82,11 +76,28 @@ class GPUMemAllocator : public SubAllocator {
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const PlatformGpuId gpu_id_;
const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
+ public:
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const string& name);
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const GPUOptions& gpu_options, const string& name);
+ ~GPUBFCAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+ static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
+};
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 67caeb3495..60e82ed13b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
}
TEST(GPUBFCAllocatorTest, NoDups) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
// Allocate a lot of raw pointers
@@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) {
}
TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Allocate 256 raw pointers of sizes between 100 bytes and about
// a meg
random::PhiloxRandom philox(123, 17);
@@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
}
TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
float* first_ptr = a.Allocate<float>(1024);
@@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
}
TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* ptr = a.Allocate<float>(0);
EXPECT_EQ(nullptr, ptr);
}
TEST(GPUBFCAllocatorTest, TracksSizes) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
@@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
}
TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
// Configure a 1MiB byte limit
- GPUBFCAllocator a(CudaGpuId(0), 1 << 20, "GPU_0_bfc");
+ GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
float* first_ptr = a.Allocate<float>(1 << 6);
float* second_ptr = a.Allocate<float>(1 << 20);
@@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
options.set_allow_growth(true);
// Max of 2GiB, but starts out small.
- GPUBFCAllocator a(CudaGpuId(0), 1LL << 31, options, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1LL << 31, "GPU_0_bfc");
// Allocate 10 raw pointers of sizes between 100 bytes and about
// 64 megs.
@@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
}
TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
- GPUBFCAllocator a(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
- GPUBFCAllocator b(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
void* amem = a.AllocateRaw(1, 1);
void* bmem = b.AllocateRaw(1, 1 << 30);
a.DeallocateRaw(amem);
@@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
}
static void BM_Allocation(int iters) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<size_t> sizes = {256, 4096, 16384, 524288,
512, 1048576, 10485760, 104857600,
@@ -289,7 +333,11 @@ static void BM_Allocation(int iters) {
BENCHMARK(BM_Allocation);
static void BM_AllocationThreaded(int iters, int num_threads) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
thread::ThreadPool pool(Env::Default(), "test", num_threads);
std::atomic_int_fast32_t count(iters);
mutex done_lock;
@@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16);
// A more complex benchmark that defers deallocation of an object for
// "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
int size_index = 0;
@@ -358,12 +410,18 @@ BENCHMARK(BM_AllocationDelayed)->Arg(1)->Arg(10)->Arg(100)->Arg(1000);
class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
protected:
+ void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
// The following test methods are called from tests. The reason for this is
// that this class is a friend class to BFCAllocator, but tests are not, so
// only methods inside this class can access private members of BFCAllocator.
void TestBinDebugInfo() {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
std::vector<void*> initial_ptrs;
std::vector<size_t> initial_ptrs_allocated_sizes;
@@ -441,7 +499,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
}
void TestLog2FloorNonZeroSlow() {
- GPUBFCAllocator a(CudaGpuId(0), 1 /* total_memory */, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2));
@@ -450,6 +512,56 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
}
+
+ void TestForceAllowGrowth() {
+ PlatformGpuId platform_gpu_id(0);
+ GPUOptions options;
+ // Unset flag value uses provided option.
+ unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ options.set_allow_growth(true);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_0_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unset_flag_allocator.curr_region_allocation_bytes_);
+
+ // Unparseable flag value uses provided option.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_1_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+ // Max of 2GiB total memory. Env variable set forces allow_growth, which
+ // does an initial allocation of 1MiB.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+ options.set_allow_growth(false);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_2_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+ // If env variable forces allow_growth disabled, all available memory is
+ // allocated.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_3_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+ force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+ }
};
TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -458,6 +570,10 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, Log2FloorNonZeroSlow) {
TestLog2FloorNonZeroSlow();
}
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+ TestForceAllowGrowth();
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 934a57a5fb..d85ca8892f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -27,10 +27,11 @@ limitations under the License.
namespace tensorflow {
-GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
@@ -60,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) {
#endif // GOOGLE_CUDA
}
-void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; }
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 856fdc34b4..8df3724bc4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -29,20 +29,18 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUcudaMallocAllocator : public VisitableAllocator {
+class GPUcudaMallocAllocator : public Allocator {
public:
- explicit GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUcudaMallocAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index e4c834b30d..989ddbe4af 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -73,10 +73,11 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
// -----------------------------------------------------------------------------
// GPUDebugAllocator
// -----------------------------------------------------------------------------
-GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
@@ -111,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
size_t GPUDebugAllocator::RequestedSize(const void* ptr) {
@@ -158,10 +151,11 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) {
// -----------------------------------------------------------------------------
// GPUNanResetAllocator
// -----------------------------------------------------------------------------
-GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
@@ -200,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
size_t GPUNanResetAllocator::RequestedSize(const void* ptr) {
return base_allocator_->RequestedSize(ptr);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index 0f9b72040c..17757a106c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -31,16 +31,14 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUDebugAllocator : public VisitableAllocator {
+class GPUDebugAllocator : public Allocator {
public:
- explicit GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUDebugAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
@@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator {
bool CheckFooter(void* ptr);
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator {
// An allocator that wraps a GPU allocator and resets the memory on
// allocation and free to 'NaN', helping to identify cases where the
// user forgets to initialize the memory.
-class GPUNanResetAllocator : public VisitableAllocator {
+class GPUNanResetAllocator : public Allocator {
public:
- explicit GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUNanResetAllocator() override;
string Name() override { return "gpu_nan_reset"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
void GetStats(AllocatorStats* stats) override;
void ClearStats() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index 236a0afa0b..aca08a7e33 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -34,10 +34,14 @@ namespace tensorflow {
namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
for (int s : {8}) {
std::vector<int64> cpu_array(s);
@@ -58,11 +62,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
for (int s : {8, 211}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -91,11 +98,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
for (int s : {8, 22}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -121,10 +131,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
}
TEST(GPUDebugAllocatorTest, ResetToNan) {
- const CudaGpuId cuda_gpu_id(0);
- GPUNanResetAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -161,13 +175,17 @@ TEST(GPUDebugAllocatorTest, ResetToNan) {
}
TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
// NaN reset must be the outer-most allocator.
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -204,18 +222,24 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
}
TEST(GPUDebugAllocatorTest, TracksSizes) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 2763ac0d4a..d8ebdeff5d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -105,9 +104,9 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
stream_ = cuda_stream;
allocator_ = alloc;
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
}
const cudaStream_t& stream() const override { return *stream_; }
@@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() {
for (auto ctx : device_contexts_) ctx->Unref();
}
+// This should be idempotent if already initialized.
+Status BaseGPUDevice::InitScratchBuffers() {
+ mutex_lock l(scratch_init_mutex_);
+ if (scratch_.size() < max_streams_) {
+ for (int i = 0; i < max_streams_; i++) {
+ DCHECK(streams_[i]);
+ if (scratch_.size() > i && scratch_[i]) continue;
+ size_t scratch_buffer_size =
+ Eigen::kCudaScratchSize + sizeof(unsigned int);
+ void* scratch_buffer = gpu_allocator_->AllocateRaw(
+ Allocator::kAllocatorAlignment, scratch_buffer_size);
+ if (scratch_buffer == nullptr) {
+ return errors::FailedPrecondition(
+ "Failed to allocate scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ se::DeviceMemory<char> mem(
+ se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
+
+ bool ok = executor_->SynchronousMemZero(
+ &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+ if (!ok) {
+ return errors::FailedPrecondition(
+ "Failed to memcopy into scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ scratch_.push_back(static_cast<char*>(scratch_buffer));
+ }
+ }
+ return Status::OK();
+}
+
Status BaseGPUDevice::Init(const SessionOptions& options) {
auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_);
if (!executor_status.status().ok()) {
@@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
for (int i = 0; i < max_streams_; i++) {
streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
tf_gpu_id_, i, executor_, options.config.gpu_options()));
-
- size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
- void* scratch_buffer = gpu_allocator_->AllocateRaw(
- Allocator::kAllocatorAlignment, scratch_buffer_size);
- if (scratch_buffer == nullptr) {
- return errors::FailedPrecondition(
- "Failed to allocate scratch buffer for device ", tf_gpu_id_.value());
- }
- scratch_.push_back(static_cast<char*>(scratch_buffer));
-
- se::DeviceMemory<char> mem(
- se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
-
- bool ok = executor_->SynchronousMemZero(
- &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
- if (!ok) {
- return errors::FailedPrecondition(
- "Failed to memcopy into scratch buffer for device ",
- tf_gpu_id_.value());
- }
-
device_contexts_.push_back(new GPUDeviceContext(
i, streams_.back()->compute, streams_.back()->host_to_device,
streams_.back()->device_to_host, streams_.back()->device_to_device));
@@ -332,9 +342,10 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- gpu_device_info_->gpu_id = cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ gpu_device_info_->gpu_id = platform_gpu_id.value();
set_tensorflow_gpu_device_info(gpu_device_info_);
// Whether and how the GPU device uses its own threadpool.
@@ -423,9 +434,6 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
}
void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name());
-
// NOTE(tucker): We need to discriminate between Eigen GPU
// operations and all others. If an operation is Eigen
// implemented (or otherwise tries to launch a cuda kernel
@@ -439,8 +447,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
context->SetStatus(errors::Internal(
"Invalid synchronous 'Compute' on GPU for '_Recv' op"));
} else {
- tracing::ScopedAnnotation annotation(op_kernel->name(),
- op_kernel->type_string());
ComputeHelper(op_kernel, context);
}
}
@@ -690,9 +696,9 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
Eigen::GpuDevice device_;
};
-// Parse 'visible_device_list' into a list of CUDA GPU ids.
+// Parse 'visible_device_list' into a list of platform GPU ids.
Status ParseVisibleDeviceList(const string& visible_device_list,
- std::vector<CudaGpuId>* visible_gpu_order) {
+ std::vector<PlatformGpuId>* visible_gpu_order) {
visible_gpu_order->clear();
se::Platform* gpu_manager = GPUMachineManager();
@@ -707,26 +713,28 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
} else {
const std::vector<string> order_str =
str_util::Split(visible_device_list, ',');
- for (const string& cuda_gpu_id_str : order_str) {
- int32 cuda_gpu_id;
- if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) {
+ for (const string& platform_gpu_id_str : order_str) {
+ int32 platform_gpu_id;
+ if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
return errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
- cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list);
+ platform_gpu_id_str, "'. visible_device_list = ",
+ visible_device_list);
}
- if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) {
+ if (platform_gpu_id < 0 ||
+ platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
return errors::InvalidArgument(
- "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id,
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
"' but visible device count is ",
gpu_manager->VisibleDeviceCount());
}
- visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id));
+ visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
}
}
// Validate no repeats.
- std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(),
- visible_gpu_order->end());
+ std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
+ visible_gpu_order->end());
if (visible_device_set.size() != visible_gpu_order->size()) {
return errors::InvalidArgument(
"visible_device_list contained a duplicate entry: ",
@@ -737,8 +745,8 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
Status VerifyVirtualDeviceSettings(
const size_t num_gpus_to_use, const GPUOptions& gpu_options,
- const std::vector<CudaGpuId>& visible_gpu_order,
- const std::vector<CudaGpuId>& valid_cuda_gpu_ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& valid_platform_gpu_ids) {
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
CHECK(!virtual_devices.empty());
if (gpu_options.per_process_gpu_memory_fraction() > 0) {
@@ -760,11 +768,11 @@ Status VerifyVirtualDeviceSettings(
" #GPUs in visible_device_list: ", visible_gpu_order.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
- if (valid_cuda_gpu_ids.size() != virtual_devices.size()) {
+ if (valid_platform_gpu_ids.size() != virtual_devices.size()) {
return errors::Unknown(
"The number of valid GPUs doesn't match the number of elements in "
"the virtual_devices list.",
- " #valid GPUs: ", valid_cuda_gpu_ids.size(),
+ " #valid GPUs: ", valid_platform_gpu_ids.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
return Status::OK();
@@ -806,18 +814,18 @@ int64 MinSystemMemory(int64 available_memory) {
}
// Get the memory limit for the virtual device being created on GPU with
-// 'cuda_gpu_id', when that virtual device is the only virtual device being
+// 'platform_gpu_id', when that virtual device is the only virtual device being
// created on that GPU.
Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
- CudaGpuId cuda_gpu_id,
+ PlatformGpuId platform_gpu_id,
int64* memory_limit) {
int64 total_memory = 0;
int64 available_memory = 0;
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
return errors::Unknown("Failed to query available memory for GPU ",
- cuda_gpu_id.value());
+ platform_gpu_id.value());
}
int64 allocated_memory = 0;
@@ -867,10 +875,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
return new ConcretePerOpGpuDevice();
}
-void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
- PerOpGpuDevice* device,
- DeviceContext* dc,
- Allocator* allocator) {
+Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+ PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) {
+ TF_RETURN_IF_ERROR(InitScratchBuffers());
if (dc) {
const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
const int stream_id = gpu_dc->stream_id();
@@ -881,6 +890,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
} else {
ReinitializeDevice(context, device, 0, allocator);
}
+ return Status::OK();
}
Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
@@ -916,8 +926,8 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
num_gpus_to_use = iter->second;
}
const auto& gpu_options = options.config.gpu_options();
- std::vector<CudaGpuId> visible_gpu_order;
- std::vector<CudaGpuId> valid_cuda_gpu_ids;
+ std::vector<PlatformGpuId> visible_gpu_order;
+ std::vector<PlatformGpuId> valid_platform_gpu_ids;
// If we aren't going to use any GPUs, don't initialize them.
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
// because it treats an empty gpu_options.visible_device_list as 'all GPUs are
@@ -926,12 +936,12 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
&visible_gpu_order));
TF_RETURN_IF_ERROR(
- GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
+ GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
}
- if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
- num_gpus_to_use = valid_cuda_gpu_ids.size();
+ if (num_gpus_to_use > valid_platform_gpu_ids.size()) {
+ num_gpus_to_use = valid_platform_gpu_ids.size();
}
- if (!valid_cuda_gpu_ids.empty()) {
+ if (!valid_platform_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
@@ -941,17 +951,18 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
// Force to implicitly initialize CUDA runtime on each valid GPU before
// CreateGPUDevice().
- for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) {
- err = cudaSetDevice(cuda_gpu_id.value());
+ for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+ err = cudaSetDevice(platform_gpu_id.value());
if (err != cudaSuccess) {
- return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("cudaSetDevice() on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
err = cudaFree(nullptr);
if (err != cudaSuccess) {
- return errors::Internal(
- "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("CUDA runtime implicit initialization on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
}
// Reset to the original device.
@@ -977,10 +988,10 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
LOG(INFO) << line_buf;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
line_buf = strings::StrCat(visible_gpu_order[i].value(), ": ");
- CudaGpuId cuda_id_i = visible_gpu_order[i];
+ PlatformGpuId gpu_id_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- CudaGpuId cuda_id_j = visible_gpu_order[j];
- if (im.directed_links.find({cuda_id_i, cuda_id_j}) !=
+ PlatformGpuId gpu_id_j = visible_gpu_order[j];
+ if (im.directed_links.find({gpu_id_i, gpu_id_j}) !=
im.directed_links.end()) {
line_buf.append("Y ");
} else {
@@ -993,22 +1004,23 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
if (!virtual_devices.empty()) {
- TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(
- num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids));
+ TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options,
+ visible_gpu_order,
+ valid_platform_gpu_ids));
// We've verified that num_gpus_to_use >= virtual_devices.size().
num_gpus_to_use = virtual_devices.size();
CHECK(gpu_options.visible_device_list().empty() ||
- valid_cuda_gpu_ids == visible_gpu_order);
+ valid_platform_gpu_ids == visible_gpu_order);
}
int next_tf_gpu_id = 0;
std::vector<int64> memory_limit_bytes;
for (int i = 0; i < num_gpus_to_use; ++i) {
- const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i];
+ const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i];
if (virtual_devices.empty() ||
virtual_devices.Get(i).memory_limit_mb_size() == 0) {
int64 single_virtual_device_memory_limit = 0;
TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit(
- gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit));
+ gpu_options, platform_gpu_id, &single_virtual_device_memory_limit));
memory_limit_bytes.push_back(single_virtual_device_memory_limit);
} else {
const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb();
@@ -1021,7 +1033,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
TfGpuId tf_gpu_id(next_tf_gpu_id);
++next_tf_gpu_id;
TF_RETURN_IF_ERROR(
- GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
+ GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id));
}
}
const int num_tf_gpus = next_tf_gpu_id;
@@ -1046,7 +1058,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
return Status::OK();
}
-static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
+static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
const se::DeviceDescription& desc) {
int cc_major;
int cc_minor;
@@ -1055,9 +1067,8 @@ static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
cc_minor = 0;
}
// LINT.IfChange
- return strings::StrCat("device: ", cuda_gpu_id.value(),
- ", name: ", desc.name(),
- ", pci bus id: ", desc.pci_bus_id(),
+ return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
+ desc.name(), ", pci bus id: ", desc.pci_bus_id(),
", compute capability: ", cc_major, ".", cc_minor);
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
@@ -1072,12 +1083,13 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
int numa_node = dev_locality.numa_node();
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
GPUProcessState* process_state = GPUProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
@@ -1098,11 +1110,11 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
- tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
- ProcessState::singleton()->GetCPUAllocator(numa_node));
+ tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
+ gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
- << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
+ << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);
@@ -1110,18 +1122,21 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
}
namespace {
-std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
+std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>>
GetPeerAccessMap(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
- std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
- new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
- for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
- for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
+ std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map(
+ new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>);
+ for (PlatformGpuId platform_gpu_i : visible_gpu_order) {
+ for (PlatformGpuId platform_gpu_j : visible_gpu_order) {
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
- (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
+ (*map)[{platform_gpu_i, platform_gpu_j}] =
+ from->CanEnablePeerAccessTo(to);
}
}
@@ -1131,19 +1146,19 @@ GetPeerAccessMap(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
- std::vector<InterconnectMap>* maps) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) {
// The default interconnect map is obtained from the StreamExecutor.
auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
maps->resize(1);
InterconnectMap& imap = maps->at(0);
imap.name = "StreamExecutor";
imap.strength = InterconnectMap::kStreamExecutorStrength;
- for (CudaGpuId cuda_id_i : visible_gpu_order) {
- for (CudaGpuId cuda_id_j : visible_gpu_order) {
- if (cuda_id_i == cuda_id_j) continue;
- if ((*access_map)[{cuda_id_i, cuda_id_j}]) {
- imap.directed_links.insert({cuda_id_i, cuda_id_j});
+ for (PlatformGpuId gpu_id_i : visible_gpu_order) {
+ for (PlatformGpuId gpu_id_j : visible_gpu_order) {
+ if (gpu_id_i == gpu_id_j) continue;
+ if ((*access_map)[{gpu_id_i, gpu_id_j}]) {
+ imap.directed_links.insert({gpu_id_i, gpu_id_j});
}
}
}
@@ -1158,13 +1173,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
all_tf_gpu_ids.push_back(TfGpuId(i));
}
for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
int numa_node = desc.numa_node();
if (numa_node < 0) {
@@ -1174,7 +1190,8 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// may run into trouble later with data transfer operations. The
// trouble may manifest as slower than expected performance, or
// outright failures.
- LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id
+ LOG(INFO) << "Could not identify NUMA node of platform GPU id "
+ << platform_gpu_id
<< ", defaulting to 0. Your kernel may not have been built "
<< "with NUMA support.";
numa_node = 0;
@@ -1187,10 +1204,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
LocalLinks* links = dev_locality.mutable_links();
for (const InterconnectMap& imap : interconnects) {
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) !=
imap.directed_links.end()) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
@@ -1204,10 +1221,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// add high strength links to the others.
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
if (tf_gpu_id == tf_gpu_dst) continue;
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (cuda_gpu_id == cuda_gpu_dst) {
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (platform_gpu_id == platform_gpu_dst) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
ilink->set_type("SAME_DEVICE");
@@ -1216,9 +1233,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
}
(*localities)[tf_gpu_id] = dev_locality;
- VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id
- << " on bus " << dev_locality.bus_id() << " numa: " << numa_node
- << " pci: " << desc.pci_bus_id()
+ VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId "
+ << tf_gpu_id << " on bus " << dev_locality.bus_id()
+ << " numa: " << numa_node << " pci: " << desc.pci_bus_id()
<< " DeviceLocality: " << dev_locality.DebugString();
}
return Status::OK();
@@ -1226,14 +1243,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
static int GetDefaultMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
static const int kDefaultMinGPUMultiprocessorCount = 8;
// Find the highest multi-processor count across all visible GPUs.
int max_count = -1;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]);
if (!exec_status.ok()) {
continue;
}
@@ -1252,7 +1269,7 @@ static int GetDefaultMinGPUMultiprocessorCount(
static int GetMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
if (tf_min_gpu_core_count == nullptr ||
@@ -1330,18 +1347,20 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
}
Status EnablePeerAccess(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
int possible_peer_count = 0;
int enabled_peer_count = 0;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_i = visible_gpu_order[i];
+ const PlatformGpuId platform_gpu_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
+ const PlatformGpuId platform_gpu_j = visible_gpu_order[j];
// We have already validated that ExecutorForDevice() calls return OK.
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
if (from->CanEnablePeerAccessTo(to)) {
++possible_peer_count;
@@ -1349,7 +1368,8 @@ Status EnablePeerAccess(se::Platform* platform,
if (!status.ok()) {
LOG(WARNING)
<< "Unable to enable peer access between device ordinals "
- << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status;
+ << platform_gpu_i << " and " << platform_gpu_j
+ << ", status: " << status;
} else {
++enabled_peer_count;
}
@@ -1372,22 +1392,23 @@ Status EnablePeerAccess(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetValidDeviceIds(
- const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids) {
se::Platform* gpu_manager = GPUMachineManager();
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
- // Only perform this once per visible cuda gpu id.
- if (visible_gpu_initialized_[cuda_gpu_id.value()]) {
+ // Only perform this once per visible platform gpu id.
+ if (visible_gpu_initialized_[visible_gpu_id.value()]) {
continue;
}
- visible_gpu_initialized_[cuda_gpu_id.value()] = true;
+ visible_gpu_initialized_[visible_gpu_id.value()] = true;
new_gpu_found = true;
- auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
+ auto executor =
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!executor.ok()) {
return executor.status();
}
@@ -1435,9 +1456,9 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
// Filter out devices that don't have the right capability or power.
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId visible_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!exec_status.ok()) {
LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id
<< " whose executor is in invalid state: "
@@ -1486,7 +1507,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
if (!ids->empty()) {
std::vector<int> raw_ids(ids->size());
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
- [](CudaGpuId id) -> int { return id.value(); });
+ [](PlatformGpuId id) -> int { return id.value(); });
LOG(INFO) << "Adding visible gpu devices: "
<< str_util::Join(raw_ids, ", ");
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 56d03d7a8c..674e8384d5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -65,6 +65,11 @@ class BaseGPUDevice : public LocalDevice {
// completes.
bool RequiresRecordingAccessedTensors() const override;
+ // GPU kernel execution requires us to use `tracing::ScopedAnnotation()`
+ // rather than `tracing::ScopedActivity()`, in order to relate asynchronously
+ // launched GPU kernels to the OpKernel.
+ bool TraceUsingAnnotations() const { return true; }
+
void ConsumeListOfAccessedTensors(
DeviceContext* device_context,
const TensorReferenceVector& tensor_refs) override;
@@ -86,15 +91,16 @@ class BaseGPUDevice : public LocalDevice {
// The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override;
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override;
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override;
- // Returns the CUDA GPU id of this device within the native driver system;
+ // Returns the platform GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
int gpu_id() const {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- return cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ return platform_gpu_id.value();
}
// The executor that provides control for the device; e.g., for CUDA this
@@ -125,6 +131,7 @@ class BaseGPUDevice : public LocalDevice {
class StreamGroupFactory;
gtl::InlinedVector<StreamGroup*, 4> streams_;
+ mutex scratch_init_mutex_;
gtl::InlinedVector<char*, 4> scratch_;
std::vector<GPUDeviceContext*> device_contexts_;
GpuDeviceInfo* gpu_device_info_ = nullptr;
@@ -135,6 +142,9 @@ class BaseGPUDevice : public LocalDevice {
std::unique_ptr<EventMgr> em_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ // Initialize scractch buffers used by Eigen.
+ Status InitScratchBuffers();
+
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
int stream_id, Allocator* allocator);
@@ -168,14 +178,14 @@ class BaseGPUDeviceFactory : public DeviceFactory {
int32 strength;
static const int kSameDeviceStrength;
static const int kStreamExecutorStrength;
- std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links;
+ std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links;
};
protected:
// Populates *maps with interconnect maps for all local direct access
// pathways between GPUs.
virtual Status GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& visible_gpu_order,
se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
struct TfGpuIdHash {
@@ -207,16 +217,16 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
- // Returns into 'ids' the list of valid CUDA GPU ids, in the order that
+ // Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
// based upon 'visible_gpu_order' which was generated by parsing
// GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
// ids.
- Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids);
+ Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids);
- // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id
- // has been initialized by the process.
+ // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
+ // platform_gpu_id has been initialized by the process.
std::unordered_map<int, bool> visible_gpu_initialized_;
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index daf59f0560..36294094e9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -30,18 +30,21 @@ namespace tensorflow {
namespace {
const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
-int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+int64 GetTotalGPUMemory(PlatformGpuId gpu_id) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
return total_memory;
}
-Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
+ int* cc_minor) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
*cc_major = 0;
*cc_minor = 0;
@@ -223,7 +226,7 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
// error.
TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor));
// Exit early while running on Pascal or later GPUs.
if (cc_major >= 6) {
return;
@@ -244,10 +247,10 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
// more memory than what is available on the device.
TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
static constexpr double kGpuMemoryFraction = 1.2;
- static constexpr CudaGpuId kCudaGpuId(0);
+ static constexpr PlatformGpuId kPlatformGpuId(0);
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor));
// Exit early if running on pre-Pascal GPUs.
if (cc_major < 6) {
LOG(INFO)
@@ -262,7 +265,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
ASSERT_EQ(1, devices.size());
int64 memory_limit = devices[0]->attributes().memory_limit();
- ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+ ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) *
kGpuMemoryFraction));
AllocatorAttributes allocator_attributes = AllocatorAttributes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 2a6caea296..f0d9022821 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -25,10 +25,10 @@ namespace tensorflow {
// physical machine, it can be filtered by CUDA environment variable
// CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but
// result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is
-// called CUDA GPU id as below. See
+// called platform GPU id as below. See
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
// for more details.
-// - CUDA GPU id (also called *visible* GPU id in
+// - *platform* GPU id (also called *visible* GPU id in
// third_party/tensorflow/core/protobuf/config.proto): this is the id that is
// visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is
// generated by the CUDA GPU driver. It starts from 0 and is used for CUDA API
@@ -39,14 +39,14 @@ namespace tensorflow {
// field of the device name "/device:GPU:<id>", and is also the identifier of
// a BaseGPUDevice. Note that the configuration allows us to create multiple
// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
-// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
+// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1
// mapping, see the example below.
//
// For example, assuming that in the machine we have GPU device with index 0, 1,
// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
-// the following mapping between CUDA GPU id and physical GPU id:
+// the following mapping between platform GPU id and physical GPU id:
//
-// CUDA GPU id -> physical GPU id
+// platform GPU id -> physical GPU id
// 0 -> 1
// 1 -> 2
// 2 -> 3
@@ -56,32 +56,32 @@ namespace tensorflow {
//
// Assuming we configure the Session to create one BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mappting between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 0
//
-// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list,
-// so it won't be used by the TF process.
+// Note that platform GPU id 1 is filtered out by
+// GPUOptions::visible_device_list, so it won't be used by the TF process.
//
// On the other hand, if we configure it to create 2 BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mappting between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 2
// 2 (i.e. /device:GPU:2) -> 0
// 3 (i.e. /device:GPU:3) -> 0
//
-// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to
-// minimize programming errors and improve code readability. Except for the
+// We create strong-typed integer classes for both TF GPU id and platform GPU id
+// to minimize programming errors and improve code readability. Except for the
// StreamExecutor interface (as we don't change its API), whenever we need a
-// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a
-// raw integer.
+// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId)
+// instead of a raw integer.
TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32);
-TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32);
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index b5099dc8ef..2b40730119 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -26,26 +26,27 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Manages the map between TfGpuId and CUDA GPU id.
-class TfToCudaGpuIdMap {
+// Manages the map between TfGpuId and platform GPU id.
+class TfToPlatformGpuIdMap {
public:
- static TfToCudaGpuIdMap* singleton() {
- static auto* id_map = new TfToCudaGpuIdMap;
+ static TfToPlatformGpuIdMap* singleton() {
+ static auto* id_map = new TfToPlatformGpuIdMap;
return id_map;
}
- Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
+ Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id)
+ LOCKS_EXCLUDED(mu_) {
std::pair<IdMapType::iterator, bool> result;
{
mutex_lock lock(mu_);
- result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
+ result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()});
}
- if (!result.second && cuda_gpu_id.value() != result.first->second) {
+ if (!result.second && platform_gpu_id.value() != result.first->second) {
return errors::AlreadyExists(
"TensorFlow device (GPU:", tf_gpu_id.value(),
") is being mapped to "
"multiple CUDA devices (",
- cuda_gpu_id.value(), " now, and ", result.first->second,
+ platform_gpu_id.value(), " now, and ", result.first->second,
" previously), which is not supported. "
"This may be the result of providing different GPU configurations "
"(ConfigProto.gpu_options, for example different visible_device_list)"
@@ -56,17 +57,17 @@ class TfToCudaGpuIdMap {
return Status::OK();
}
- bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+ bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
auto result = id_map_.find(tf_gpu_id.value());
if (result == id_map_.end()) return false;
- *cuda_gpu_id = result->second;
+ *platform_gpu_id = result->second;
return true;
}
private:
- TfToCudaGpuIdMap() = default;
+ TfToPlatformGpuIdMap() = default;
void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
@@ -78,17 +79,18 @@ class TfToCudaGpuIdMap {
IdMapType id_map_ GUARDED_BY(mu_);
friend class ::tensorflow::GpuIdManager;
- TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
+ TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap);
};
} // namespace
-Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
- CudaGpuId cuda_gpu_id) {
- return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id) {
+ return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id);
}
-Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
- if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id) {
+ if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) {
return Status::OK();
}
return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
@@ -96,7 +98,7 @@ Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
}
void GpuIdManager::TestOnlyReset() {
- TfToCudaGpuIdMap::singleton()->TestOnlyReset();
+ TfToPlatformGpuIdMap::singleton()->TestOnlyReset();
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 491d92ccdd..62df4310c4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -21,15 +21,17 @@ limitations under the License.
namespace tensorflow {
-// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the
// translation between them.
class GpuIdManager {
public:
- // Adds a mapping from tf_gpu_id to cuda_gpu_id.
- static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+ // Adds a mapping from tf_gpu_id to platform_gpu_id.
+ static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id);
- // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
- static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+ // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found.
+ static Status TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id);
// Clears the map. Used in unit tests only.
static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index a663ec7051..8bf3c6a308 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -22,38 +22,38 @@ limitations under the License.
namespace tensorflow {
namespace {
-CudaGpuId TfToCudaGpuId(TfGpuId tf) {
- CudaGpuId cuda;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
- return cuda;
+PlatformGpuId TfToPlatformGpuId(TfGpuId tf) {
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id));
+ return platform_gpu_id;
}
TEST(GpuIdManagerTest, Basics) {
TfGpuId key_0(0);
- CudaGpuId value_0(0);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ PlatformGpuId value_0(0);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Multiple calls to map the same value is ok.
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Map a different TfGpuId to a different value.
TfGpuId key_1(3);
- CudaGpuId value_1(2);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
+ PlatformGpuId value_1(2);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_1));
// Mapping a different TfGpuId to the same value is ok.
TfGpuId key_2(10);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_2));
// Mapping the same TfGpuId to a different value.
- ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
+ ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok());
// Getting a nonexistent mapping.
- ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
+ ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok());
}
} // namespace
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index b9c66b3328..b1f10fb1dc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,34 +24,37 @@ limitations under the License.
namespace tensorflow {
-// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
+// Utility methods for translation between Tensorflow GPU ids and platform GPU
+// ids.
class GpuIdUtil {
public:
// Convenient methods for getting the associated executor given a TfGpuId or
- // CudaGpuId.
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
- return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
+ // PlatformGpuId.
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) {
+ return gpu_manager->ExecutorForDevice(platform_gpu_id.value());
}
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- CudaGpuId cuda_gpu_id) {
- return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ PlatformGpuId platform_gpu_id) {
+ return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id);
}
static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- return ExecutorForCudaGpuId(cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ return ExecutorForPlatformGpuId(platform_gpu_id);
}
- // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
+ // Verify that the platform_gpu_id associated with a TfGpuId is legitimate.
static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
- CHECK_LT(cuda_gpu_id.value(), visible_device_count)
- << "cuda_gpu_id is outside discovered device range."
- << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id
+ CHECK_LT(platform_gpu_id.value(), visible_device_count)
+ << "platform_gpu_id is outside discovered device range."
+ << " TF GPU id: " << tf_gpu_id
+ << " platform GPU id: " << platform_gpu_id
<< " visible device count: " << visible_device_count;
}
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index b18688174d..3e95374fda 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
// This function is defined for debugging problems with the allocators.
GPUProcessState::~GPUProcessState() {
CHECK_EQ(this, instance_);
- for (auto p : gpu_allocators_) {
- delete p;
- }
instance_ = nullptr;
}
+int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
+ // Return the NUMA node associated with the GPU's StreamExecutor.
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
+ return se->GetDeviceDescription().numa_node();
+}
+
Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
TfGpuId tf_gpu_id,
size_t total_bytes) {
@@ -93,64 +97,63 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
gpu_allocators_.resize(tf_gpu_id.value() + 1);
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- gpu_al_.resize(tf_gpu_id.value() + 1);
}
- if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
- VisitableAllocator* gpu_allocator;
-
+ AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+ if (allocator_parts.allocator.get() == nullptr) {
// Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") {
LOG(ERROR) << "Invalid allocator type: " << allocator_type;
return nullptr;
}
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- gpu_allocator =
- new GPUBFCAllocator(cuda_gpu_id, total_bytes, options,
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ int bus_id = BusIdForGPU(tf_gpu_id);
+ while (bus_id >= gpu_visitors_.size()) {
+ gpu_visitors_.push_back({});
+ }
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id,
+ (options.per_process_gpu_memory_fraction() > 1.0 ||
+ options.experimental().use_unified_memory()),
+ gpu_visitors_[bus_id], {});
+ Allocator* gpu_allocator =
+ new GPUBFCAllocator(sub_allocator, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
// If true, checks for memory overwrites by writing
// distinctive patterns on both ends of allocated memory.
if (useCudaMemoryGuardAllocator()) {
- gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id);
- gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id);
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
} else if (useCudaMallocAllocator()) {
// If true, passes all allocation requests through to cudaMalloc
// useful for doing memory debugging with tools like cuda-memcheck
// **WARNING** probably will not work in a multi-gpu scenario
- gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id);
- }
- gpu_allocators_[tf_gpu_id.value()] = gpu_allocator;
-
- // If there are any pending AllocVisitors for this bus, add
- // them now.
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
- int bus_id = se->GetDeviceDescription().numa_node();
- if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
- for (const auto& v : gpu_visitors_[bus_id]) {
- gpu_allocator->AddAllocVisitor(v);
- }
+ gpu_allocator =
+ new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
}
+
+ Allocator* recording_allocator = nullptr;
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::GPU;
- md.dev_index = cuda_gpu_id.value();
+ md.dev_index = platform_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
- if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) {
- gpu_al_.resize(tf_gpu_id.value() + 1);
- }
- gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
+ recording_allocator = new internal::RecordingAllocator(
&process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
+ allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+ std::unique_ptr<Allocator>(recording_allocator)};
+ }
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return allocator_parts.recording_allocator.get();
+ } else {
+ return allocator_parts.allocator.get();
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return gpu_al_[tf_gpu_id.value()];
- return gpu_allocators_[tf_gpu_id.value()];
#else
LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
return nullptr;
@@ -172,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
tf_shared_lock lock(mu_);
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
- static_cast<int>(cuda_al_.size()) > 0) {
- return cuda_al_[0];
+ !cuda_host_allocators_.empty() &&
+ cuda_host_allocators_[0].recording_allocator != nullptr) {
+ return cuda_host_allocators_[0].recording_allocator.get();
}
if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
- return cuda_host_allocators_[0];
+ return cuda_host_allocators_[0].allocator.get();
}
}
@@ -190,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
// it knows is valid.
se::StreamExecutor* se = nullptr;
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
- if (gpu_allocators_[i] != nullptr) {
+ if (gpu_allocators_[i].allocator != nullptr) {
se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
break;
}
@@ -199,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
CHECK_NE(nullptr, se);
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+ while (cuda_host_alloc_visitors_.size() <= numa_node) {
+ cuda_host_alloc_visitors_.push_back({});
+ }
+ while (cuda_host_free_visitors_.size() <= numa_node) {
+ cuda_host_free_visitors_.push_back({});
+ }
+ SubAllocator* sub_allocator = new CUDAHostAllocator(
+ se, numa_node, cuda_host_alloc_visitors_[numa_node],
+ cuda_host_free_visitors_[numa_node]);
// TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
int64 cuda_host_mem_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
@@ -208,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
}
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
- VisitableAllocator* allocator =
- new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
+ Allocator* allocator =
+ new BFCAllocator(sub_allocator, cuda_host_mem_limit,
true /*allow_growth*/, "cuda_host_bfc" /*name*/);
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
- cuda_host_allocators_.push_back(allocator);
+ cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+ sub_allocator,
+ std::unique_ptr<Allocator>(nullptr)});
+ AllocatorParts& allocator_parts = cuda_host_allocators_.back();
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::CPU;
md.dev_index = 0;
md.gpu_registered = true;
md.nic_registered = false;
- cuda_al_.push_back(new internal::RecordingAllocator(
- &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
- &mu_));
+ allocator_parts.recording_allocator.reset(
+ new internal::RecordingAllocator(&process_state_->mem_desc_map_,
+ allocator_parts.allocator.get(), md,
+ &mu_));
}
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return cuda_al_[0];
- return cuda_host_allocators_[0];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return cuda_host_allocators_[0].recording_allocator.get();
+ } else {
+ return cuda_host_allocators_[0].allocator.get();
+ }
}
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
- const AllocVisitor& visitor) {
- CHECK(process_state_);
+ const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
- for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
- if (gpu_allocators_[i] &&
- (se->GetDeviceDescription().numa_node() + 1) == bus_id) {
- gpu_allocators_[i]->AddAllocVisitor(visitor);
- }
- }
+ CHECK(gpu_allocators_.empty()) // Crash OK
+ << "AddGPUAllocVisitor must be called before "
+ "first call to GetGPUAllocator.";
while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
- gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
}
gpu_visitors_[bus_id].push_back(visitor);
#endif // GOOGLE_CUDA
}
+void GPUProcessState::AddCUDAHostAllocVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostAllocVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
+ cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_alloc_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+void GPUProcessState::AddCUDAHostFreeVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostFreeVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
+ cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_free_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
void GPUProcessState::TestOnlyReset() {
- process_state_->ProcessState::TestOnlyReset();
+ if (process_state_) {
+ process_state_->ProcessState::TestOnlyReset();
+ }
{
mutex_lock lock(mu_);
gpu_device_enabled_ = false;
+ gpu_allocators_.clear();
gpu_visitors_.clear();
- gtl::STLDeleteElements(&gpu_allocators_);
- gtl::STLDeleteElements(&cuda_host_allocators_);
- gtl::STLDeleteElements(&gpu_al_);
- gtl::STLDeleteElements(&cuda_al_);
+ cuda_host_allocators_.clear();
+ cuda_host_alloc_visitors_.clear();
+ cuda_host_free_visitors_.clear();
}
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index cb41c3c6bd..43e9a31660 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -32,7 +32,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state when GPUs are present.
@@ -72,18 +71,30 @@ class GPUProcessState {
virtual Allocator* GetCUDAHostAllocator(int numa_node);
- // Registers a function to be called once on every new Region
- // allocated by every GPURegionAllocator proximate to the specified
- // bus. The AllocVisitor is provided with a memory pointer and the
- // size of the area it identifies. The pointer is not guaranteed to
- // be valid after the call terminates. The intention is for this
- // interface to be used for network device memory registration.
- // "bus_id" is platform-specific. On many platforms it
- // should be 0. On machines with multiple PCIe buses, it should be
- // the index of one of the PCIe buses. If the bus_id is invalid,
- // results are undefined.
- typedef std::function<void(void*, size_t)> AllocVisitor;
- virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
+ // Registers a Visitor to be invoked on new chunks of memory allocated by the
+ // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor
+ // is provided with a memory pointer, a GPU id, and the size of the area it
+ // identifies. The pointer is not guaranteed to be valid after the call
+ // terminates. The intention is for this interface to be used for network
+ // device memory registration. "bus_id" is platform-specific. On many
+ // platforms it should be 0. On machines with multiple PCIe buses, it should
+ // be the index of one of the PCIe buses (maybe the NUMA node at which the
+ // PCIe is rooted). If the bus_id is invalid, results are undefined.
+ virtual void AddGPUAllocVisitor(int bus_id,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on new chunks of memory allocated by
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostAllocVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on each chunk handed back for freeing to
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostFreeVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Returns bus_id for the given GPU id.
+ virtual int BusIdForGPU(TfGpuId tf_gpu_id);
protected:
GPUProcessState();
@@ -103,16 +114,21 @@ class GPUProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
- std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+ struct AllocatorParts {
+ std::unique_ptr<Allocator> allocator;
+ SubAllocator* sub_allocator; // owned by allocator
+ std::unique_ptr<Allocator> recording_allocator;
+ };
+ std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
- virtual ~GPUProcessState();
+ std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+ GUARDED_BY(mu_);
- // Optional RecordingAllocators that wrap the corresponding
- // Allocators for runtime attribute use analysis.
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+ virtual ~GPUProcessState();
friend class GPUDeviceTest;
};
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 583bff2c07..6b2f6547b0 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
@@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(0, pool.get_from_pool_count());
@@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
for (int i = 0; i < 16; ++i) {
size_t alignment = 1 << i;
@@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) {
TEST(PoolAllocatorTest, AutoResize) {
PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
- "pool");
+ new BasicCPUAllocator(0 /*numa_node*/, {}, {}),
+ new NoopRounder, "pool");
// Alloc/dealloc 10 sizes just a few times, confirming pool size
// stays at 2.
@@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) {
}
TEST(PoolAllocatorTest, CudaHostAllocator) {
+ int alloc_count = 0;
+ int64 alloc_size = 0;
+ SubAllocator::Visitor alloc_visitor =
+ [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) {
+ ++alloc_count;
+ alloc_size += size;
+ };
+ int free_count = 0;
+ int64 free_size = 0;
+ SubAllocator::Visitor free_visitor =
+ [&free_count, &free_size](void* ptr, int numa_node, int64 size) {
+ ++free_count;
+ free_size += size;
+ };
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
- PoolAllocator pool(
- 2 /*pool_size_limit*/, false /*auto_resize*/,
- new CUDAHostAllocator(
- platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
- new NoopRounder, "pool");
+ CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie(),
+ 0 /*numa_node*/, {alloc_visitor}, {free_visitor});
+ PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
+ sub_allocator, new NoopRounder, "pool");
+ EXPECT_EQ(0, alloc_count);
+ EXPECT_EQ(0, alloc_size);
+ EXPECT_EQ(0, free_count);
+ EXPECT_EQ(0, free_size);
// Repeatedly Get a 16-byte value, confirming that there's only
// one real allocation.
@@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(1, pool.allocated_count());
EXPECT_NE(nullptr, p1_16);
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ // Each suballocation includes a 16B ChunkPrefix.
+ static const int kChunkPrefixSize = 16;
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
pool.DeallocateRaw(p1_16);
// Pool contents {16}
EXPECT_EQ(1, pool.put_count());
@@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
pool.DeallocateRaw(p2_16); // Put it back.
// Pool contents {16}
EXPECT_EQ(2, pool.put_count());
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// Get two more values of different sizes.
void* p3_4 = pool.AllocateRaw(4, 4);
@@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
EXPECT_NE(nullptr, p4_2);
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// The pool is full: when we put back p4_2, the 16-byte buffer
// should be evicted since it was least recently inserted.
@@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
// Pool contents {2, 4}
EXPECT_EQ(4, pool.put_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
// Re-getting and putting size 2 or 4 should not alter pool size or
// num-evicted.
@@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(6, pool.put_count());
EXPECT_EQ(3, pool.allocated_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
pool.Clear();
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(0, pool.put_count());
EXPECT_EQ(0, pool.allocated_count());
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(3, free_count);
+ EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size);
}
TEST(PoolAllocatorTest, Pow2Rounder) {
@@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ("pool", pool.Name());
}
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index db5022d56e..873182371e 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -62,7 +62,7 @@ struct LocalDevice::EigenThreadPoolInfo {
LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes)
- : TracingDevice(options.env, attributes), owned_tp_info_(nullptr) {
+ : Device(options.env, attributes), owned_tp_info_(nullptr) {
// Log info messages if TensorFlow is not compiled with instructions that
// could speed up performance and are available on the current CPU.
port::InfoAboutUnusedCPUFeatures();
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index 9a82fb7204..226f121bf3 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/tracing_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -32,7 +31,7 @@ struct SessionOptions;
// initializes a shared Eigen compute device used by both. This
// should eventually be removed once we refactor ThreadPoolDevice and
// GPUDevice into more 'process-wide' abstractions.
-class LocalDevice : public TracingDevice {
+class LocalDevice : public Device {
public:
LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes);
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index df9c3a686c..538a70668a 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -23,12 +23,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
-#include "tensorflow/core/framework/allocator_registry.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
-#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -40,20 +39,16 @@ typedef unsigned int uint;
namespace tensorflow {
-class MklSubAllocator : public SubAllocator {
+class MklSubAllocator : public BasicCPUAllocator {
public:
+ MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {}
~MklSubAllocator() override {}
-
- void* Alloc(size_t alignment, size_t num_bytes) override {
- return port::AlignedMalloc(num_bytes, alignment);
- }
- void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
};
// CPU allocator that handles small-size allocations by calling
// suballocator directly. Mostly, it is just a wrapper around a suballocator
// (that calls malloc and free directly) with support for bookkeeping.
-class MklSmallSizeAllocator : public VisitableAllocator {
+class MklSmallSizeAllocator : public Allocator {
public:
MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
const string& name)
@@ -75,10 +70,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
CHECK(map_.insert(map_val).second);
// Increment statistics for small-size allocations.
IncrementStats(num_bytes);
- // Call alloc visitors.
- for (const auto& visitor : alloc_visitors_) {
- visitor(ptr, num_bytes);
- }
}
return ptr;
}
@@ -94,9 +85,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
if (map_iter != map_.end()) {
// Call free visitors.
size_t dealloc_bytes = map_iter->second;
- for (const auto& visitor : free_visitors_) {
- visitor(ptr, dealloc_bytes);
- }
sub_allocator_->Free(ptr, dealloc_bytes);
DecrementStats(dealloc_bytes);
map_.erase(map_iter);
@@ -121,16 +109,6 @@ class MklSmallSizeAllocator : public VisitableAllocator {
stats_.Clear();
}
- void AddAllocVisitor(Visitor visitor) override {
- mutex_lock l(mutex_);
- alloc_visitors_.push_back(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- mutex_lock l(mutex_);
- free_visitors_.push_back(visitor);
- }
-
private:
// Increment statistics for the allocator handling small allocations.
inline void IncrementStats(size_t alloc_size)
@@ -163,15 +141,11 @@ class MklSmallSizeAllocator : public VisitableAllocator {
// Allocator stats for small allocs
AllocatorStats stats_ GUARDED_BY(mutex_);
-
- // Visitors
- std::vector<Visitor> alloc_visitors_ GUARDED_BY(mutex_);
- std::vector<Visitor> free_visitors_ GUARDED_BY(mutex_);
};
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
-class MklCPUAllocator : public VisitableAllocator {
+class MklCPUAllocator : public Allocator {
public:
// Constructor and other standard functions
@@ -284,16 +258,6 @@ class MklCPUAllocator : public VisitableAllocator {
large_size_allocator_->ClearStats();
}
- void AddAllocVisitor(Visitor visitor) override {
- small_size_allocator_->AddAllocVisitor(visitor);
- large_size_allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- small_size_allocator_->AddFreeVisitor(visitor);
- large_size_allocator_->AddFreeVisitor(visitor);
- }
-
private:
// Hooks provided by this allocator for memory allocation routines from MKL
@@ -330,7 +294,7 @@ class MklCPUAllocator : public VisitableAllocator {
// The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* large_size_allocator_; // owned by this class
+ Allocator* large_size_allocator_; // owned by this class
MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
SubAllocator* sub_allocator_; // not owned by this class
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index f9f36443a8..6af4ca4d96 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -50,8 +50,8 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
for (Node* n : matches) {
AttrSlice n_attrs = n->attrs();
- auto base_make_node = [n, g, &n_attrs](const string& op,
- const string& name) {
+ auto base_make_node = [n, &n_attrs](const string& op,
+ const string& name) {
NodeBuilder node_builder(name, op);
node_builder.Device(n->requested_device());
string colo;
@@ -60,7 +60,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
return node_builder;
};
- auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+ auto make_node = [n, g, &base_make_node](string op) {
return base_make_node(
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
};
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index fdad8de8d6..66dc8f3322 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -40,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
auto_resize_(auto_resize),
pool_size_limit_(pool_size_limit),
allocator_(allocator),
- size_rounder_(size_rounder),
- allocation_begun_(false) {
+ size_rounder_(size_rounder) {
if (auto_resize) {
CHECK_LT(size_t{0}, pool_size_limit)
<< "size limit must be > 0 if auto_resize is true.";
@@ -93,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) {
} // namespace
void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
- if (!allocation_begun_) allocation_begun_ = true;
if (num_bytes == 0) return nullptr;
// If alignment is larger than kPoolAlignment, increase num_bytes so that we
@@ -129,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return PrepareChunk(r, alignment, num_bytes);
} else {
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
- for (const auto& v : alloc_visitors_) {
- v(ptr, num_bytes);
- }
return PrepareChunk(ptr, alignment, num_bytes);
}
}
@@ -141,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
ChunkPrefix* cp = FindPrefix(ptr);
CHECK_LE((void*)cp, (void*)ptr);
if (!has_size_limit_ && !auto_resize_) {
- for (const auto& v : free_visitors_) {
- v(cp, cp->num_bytes);
- }
allocator_->Free(cp, cp->num_bytes);
} else {
mutex_lock lock(mutex_);
@@ -164,9 +156,6 @@ void PoolAllocator::Clear() {
mutex_lock lock(mutex_);
for (auto iter : pool_) {
PtrRecord* pr = iter.second;
- for (const auto& v : free_visitors_) {
- v(pr->ptr, pr->num_bytes);
- }
allocator_->Free(pr->ptr, pr->num_bytes);
delete pr;
}
@@ -221,9 +210,6 @@ void PoolAllocator::EvictOne() {
DCHECK(iter != pool_.end());
}
pool_.erase(iter);
- for (const auto& v : free_visitors_) {
- v(prec->ptr, prec->num_bytes);
- }
allocator_->Free(prec->ptr, prec->num_bytes);
delete prec;
++evicted_count_;
@@ -269,28 +255,19 @@ void PoolAllocator::EvictOne() {
}
}
-void PoolAllocator::AddAllocVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after pool allocation "
- << "has begun.";
- alloc_visitors_.push_back(visitor);
-}
-
-void PoolAllocator::AddFreeVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after pool allocation "
- << "has begun.";
- free_visitors_.push_back(visitor);
-}
-
void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
- return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ VisitAlloc(ptr, numa_node_, num_bytes);
+ }
+ return ptr;
}
void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
- port::AlignedFree(ptr);
+ if (num_bytes > 0) {
+ VisitFree(ptr, numa_node_, num_bytes);
+ port::AlignedFree(ptr);
+ }
}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 607734445b..5b4623ba10 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
-// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface.
+// Simple LRU pool allocators for various flavors of CPU RAM.
#include <atomic>
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,7 +40,7 @@ class RoundUpInterface {
// Size-limited pool of memory buffers obtained from a SubAllocator
// instance. Pool eviction policy is LRU.
-class PoolAllocator : public VisitableAllocator {
+class PoolAllocator : public Allocator {
public:
// "pool_size_limit" is the maximum number of returned, re-usable
// memory buffers to keep in the pool. If pool_size_limit == 0, the
@@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator {
void DeallocateRaw(void* ptr) override;
- // REQUIRES: The following functions may only be called prior
- // to the first Allocate*() call. Once allocation has begun, it is
- // illegal to register another visitor.
-
- void AddAllocVisitor(Visitor visitor) override;
-
- void AddFreeVisitor(Visitor visitor) override;
-
// Allocate an unused memory region of size "num_bytes". Fetch from
// the pool if available, otherwise call allocator_.
void* Get(size_t num_bytes);
@@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator {
int64 put_count_ GUARDED_BY(mutex_) = 0;
int64 allocated_count_ GUARDED_BY(mutex_) = 0;
int64 evicted_count_ GUARDED_BY(mutex_) = 0;
- // Write access to these is guarded by mutex_, but not read
- // access. They may only be modified prior to the first
- // allocation. Later attempts to modify will fail.
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
};
// Do-nothing rounder. Passes through sizes unchanged.
@@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface {
class BasicCPUAllocator : public SubAllocator {
public:
// Argument numa_node is currently ignored.
- explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+ BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}
~BasicCPUAllocator() override {}
@@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator {
private:
int numa_node_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 447338e7bd..bcaa37fc8a 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+ // If visitors have been defined we need an Allocator built from
+ // a SubAllocator. Prefer BFCAllocator, but fall back to PoolAllocator
+ // depending on env var setting.
+ const bool alloc_visitors_defined =
+ (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty());
bool use_bfc_allocator = false;
- // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
- // efficient.
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
- &use_bfc_allocator);
+ Status status = ReadBoolFromEnvVar(
+ "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator);
if (!status.ok()) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
- VisitableAllocator* allocator;
+ Allocator* allocator = nullptr;
+ SubAllocator* sub_allocator =
+ (alloc_visitors_defined || use_bfc_allocator)
+ ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1,
+ cpu_alloc_visitors_, cpu_free_visitors_)
+ : nullptr;
if (use_bfc_allocator) {
// TODO(reedwm): evaluate whether 64GB by default is the best choice.
int64 cpu_mem_limit_in_mb = -1;
@@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
- allocator = new BFCAllocator(
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
- true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+ DCHECK(sub_allocator);
+ allocator =
+ new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/,
+ "bfc_cpu_allocator_for_gpu" /*name*/);
VLOG(2) << "Using BFCAllocator with memory limit of "
<< cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
- } else {
- allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
- new NoopRounder, "cpu_pool");
+ } else if (alloc_visitors_defined) {
+ DCHECK(sub_allocator);
+ allocator =
+ new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+ sub_allocator, new NoopRounder, "cpu_pool");
VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
<< "numa_enabled_=" << numa_enabled_
<< " numa_node=" << numa_node;
+ } else {
+ DCHECK(!sub_allocator);
+ allocator = cpu_allocator();
}
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
cpu_allocators_.push_back(allocator);
+ if (!sub_allocator) {
+ DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty());
+ }
}
return cpu_allocators_[numa_node];
}
+void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) {
+ VLOG(1) << "AddCPUAllocVisitor";
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUAllocVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_alloc_visitors_.push_back(std::move(visitor));
+}
+
+void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) {
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUFreeVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_free_visitors_.push_back(std::move(visitor));
+}
+
void ProcessState::TestOnlyReset() {
mutex_lock lock(mu_);
+ // Don't delete this value because it's static.
+ Allocator* default_cpu_allocator = cpu_allocator();
mem_desc_map_.clear();
- gtl::STLDeleteElements(&cpu_allocators_);
+ for (Allocator* a : cpu_allocators_) {
+ if (a != default_cpu_allocator) delete a;
+ }
+ cpu_allocators_.clear();
gtl::STLDeleteElements(&cpu_al_);
}
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 2892677333..cac312d849 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -30,7 +30,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state, e.g. allocation of
@@ -65,7 +64,15 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- VisitableAllocator* GetCPUAllocator(int numa_node);
+ Allocator* GetCPUAllocator(int numa_node);
+
+ // Registers alloc visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUAllocVisitor(SubAllocator::Visitor v);
+
+ // Registers free visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUFreeVisitor(SubAllocator::Visitor v);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +94,9 @@ class ProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_);
virtual ~ProcessState();
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 103eee03b3..c00789a556 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,6 +58,15 @@ class RenamedDevice : public Device {
return underlying_->GetAllocator(attr);
}
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override {
+ return underlying_->GetScopedAllocator(attr, step_id);
+ }
+
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return underlying_->GetScopedAllocatorMgr();
+ }
+
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
@@ -72,9 +81,10 @@ class RenamedDevice : public Device {
return underlying_->MakeGpuDevice();
}
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override {
- underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override {
+ return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 1e3fed0d6f..43ca3f1e3e 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/reffed_status_callback.h"
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4229..0000000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
- std::shared_ptr<Session> session;
- uint64* value;
- mutex* m;
- condition_variable* cv;
-
- explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
- condition_variable* cv)
- : session(std::move(s)), value(v), m(m), cv(cv) {
- mutex_lock l(*m);
- ++*value;
- }
-
- ~RunCounter() {
- mutex_lock l(*m);
- if (--*value == 0) {
- cv->notify_all();
- }
- }
-};
-
-} // namespace
-
-Status SessionRef::CheckNotClosed() {
- mutex_lock l(run_lock_);
- if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
- return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(run_options, inputs, output_tensor_names,
- target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close(run_options);
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Close() {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close();
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(inputs, output_tensor_names, target_node_names,
- outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
- run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ReleaseCallable(handle);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
index 04d5af9087..22650b0d83 100644
--- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 836cb8ed14..a70ab93d4a 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@@ -40,46 +41,24 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name)
- : NodeExecStatsWrapper(new NodeExecStats) {
- stats_->set_node_name(node_name);
-}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
- : stats_(stats) {}
-
-void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) {
- DCHECK(v);
- NodeOutput* no = stats_->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
-}
-
-void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats_->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ const Node* node, StepStatsCollector* step_stats_collector)
+ : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+ step_stats_collector) {
+ stats_->set_node_name(node->name());
}
-void NodeExecStatsWrapper::SetReferencedTensors(
- const TensorReferenceVector& tensors) {
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description = stats_->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
- }
-}
-
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
- bool is_transfer_node = false;
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector)
+ : stats_(std::move(stats)),
+ node_(node),
+ step_stats_collector_(step_stats_collector) {}
+
+void NodeExecStatsWrapper::Done(const string& device) {
+ // TODO(tucker): merge with the DetailText function in session.cc in a common
+ // location.
+ DCHECK(node_);
string memory;
for (auto& all : stats_->memory()) {
int64 tot = all.total_bytes();
@@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) {
}
}
}
- const AttrSlice attrs = node->attrs();
+ const AttrSlice attrs = node_->attrs();
string text;
- if (IsSend(node)) {
+ if (IsSend(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string recv_device;
TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
+ } else if (IsRecv(node_)) {
string tensor_name;
TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
string send_device;
TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
"(", tensor_name, " @", send_device);
- is_transfer_node = true;
} else {
text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
+ strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+ str_util::Join(node_->requested_inputs(), ", "), ")");
}
stats_->set_timeline_label(text);
- return is_transfer_node;
+ step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+ DCHECK(tensor);
+ NodeOutput* node_output = stats_->add_output();
+ node_output->set_slot(slot);
+ tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // be careful not to increment the reference count on any tensor
+ // while recording the information
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
}
void NodeExecStatsWrapper::AddAllocation(
@@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() {
allocations_.clear();
}
-StepStatsCollector::StepStatsCollector(StepStats* ss)
- : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+ : finalized_(false), step_stats_(step_stats) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel(
}
}
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+ NodeExecStats* node_stats_pb) {
+ Save(device,
+ new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+ nullptr, this));
}
void StepStatsCollector::Save(const string& device,
- NodeExecStatsWrapper* stats) {
- if (!stats) return;
- VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+ NodeExecStatsWrapper* node_stats) {
+ if (!node_stats) return;
+ VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
{
mutex_lock l(mu_);
if (finalized_) {
LOG(WARNING) << "stats saved after finalize will not be collected.";
}
- if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+ if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete stats;
+ delete node_stats;
return;
}
- auto& dss = dev_stats_[device];
- dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
- collectedNodes++;
+ auto& device_stats = dev_stats_[device];
+ device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+ collected_nodes_++;
+ }
+}
+
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+ const Node* node) {
+ // Only collect statistics for non-transfer nodes.
+ if (IsSend(node) || IsRecv(node)) {
+ return nullptr;
}
+ return new NodeExecStatsWrapper(node, this);
}
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
@@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() {
FinalizeInternal();
}
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
mutex_lock l(mu_);
CHECK(step_stats_);
FinalizeInternal();
- ss->Swap(step_stats_);
- collectedNodes = 0;
+ step_stats->Swap(step_stats_);
+ collected_nodes_ = 0;
}
void StepStatsCollector::FinalizeInternal() {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7206fbf427..4365b11b19 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -36,81 +36,78 @@ class Node;
class NodeExecStats;
class OpKernelContext;
class StepStats;
+class StepStatsCollector;
class Tensor;
class TrackingAllocator;
-// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
public:
- NodeExecStatsWrapper(const string& node_name);
- // Owns 'stats'.
- NodeExecStatsWrapper(NodeExecStats* stats);
+ virtual ~NodeExecStatsInterface() {}
- // Destructor calls Finalize() to release the TrackingAllocators.
- ~NodeExecStatsWrapper() { Finalize(); }
-
- // Records the absolute time in nanoseconds at which this node became
- // runnable (i.e. was scheduled for execution).
- void SetScheduled(int64 nanos) {
- stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats_->set_scheduled_nanos(nanos);
- }
+ // Called when the statistics collection for the node has finished. Once this
+ // method is called, the caller should not make assumptions about the validity
+ // of this object.
+ virtual void Done(const string& device) = 0;
// Called immediately after this node starts being processed by the executor.
- void RecordExecutorStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats_->set_all_start_nanos(now_nanos);
- }
+ virtual void RecordExecutorStarted() = 0;
// Called immediately before this node's `Compute()` or `ComputeAsync()`
// method is called.
- void RecordComputeStarted() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeStarted() = 0;
// Called immediately after this node's `Compute()` method returned (or, for
// asynchronous operations, the callback passed to its `ComputeAsync()` method
// was called).
- void RecordComputeEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
+ virtual void RecordComputeEnded() = 0;
// Called immediately after this executor finishes processing this node.
- void RecordExecutorEnded() {
- int64 now_nanos = Env::Default()->NowNanos();
- DCHECK_NE(stats_->all_start_micros(), 0);
- DCHECK_NE(stats_->all_start_nanos(), 0);
- stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- stats_->all_start_micros());
- stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
- }
-
- // Records information about the tensor produced by this node at the given
- // output slot.
- void SetOutput(int slot, const Tensor* v);
+ virtual void RecordExecutorEnded() = 0;
// Records information about the memory allocated during the execution of this
// node.
- void SetMemory(OpKernelContext* ctx);
+ virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ virtual void SetOutput(int slot, const Tensor* tensor) = 0;
// Records information about the tensors that were accessed during the
// execution of this node.
- void SetReferencedTensors(const TensorReferenceVector& tensors);
+ virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
- // Sets the timeline_label field of the wrapped NodeExecStats, using data
- // from *node. Returns true iff the node is a transfer node.
- bool SetTimelineLabel(const Node* node);
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ virtual void SetScheduled(int64 nanos) = 0;
+};
+
+// Wraps NodeExecStats and adds allocation to it.
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
+ public:
+ // Does not take ownership of `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Takes ownership of 'stats' but not `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Destructor calls Finalize() to release the TrackingAllocators.
+ ~NodeExecStatsWrapper() { Finalize(); }
+
+ void Done(const string& device) override;
+ void RecordExecutorStarted() override;
+ void RecordComputeStarted() override;
+ void RecordComputeEnded() override;
+ void RecordExecutorEnded() override;
+ void SetMemory(OpKernelContext* ctx) override;
+ void SetOutput(int slot, const Tensor* tensor) override;
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+ void SetScheduled(int64 nanos) override;
private:
friend class StepStatsCollector;
@@ -128,9 +125,11 @@ class NodeExecStatsWrapper {
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
+ const Node* const node_; // Not owned.
+ StepStatsCollector* const step_stats_collector_; // Not owned.
};
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
//
// See `StepStatsCollector` for a concrete implementation of this interface
// that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@ class StepStatsCollectorInterface {
public:
virtual ~StepStatsCollectorInterface() {}
- // Saves `stats` to the collector.
- virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+ // Creates an instance of `NodeExecStatsInterface` that should be used for
+ // collecting statistics about individual node execution.
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
// Generates a string reporting the currently used memory based
// on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@ class StepStatsCollectorInterface {
// Each DeviceStats object holds multiple NodeExecStats.
class StepStatsCollector : public StepStatsCollectorInterface {
public:
- // Does not take ownership of `ss`.
- explicit StepStatsCollector(StepStats* ss);
+ // Does not take ownership of `step_stats`.
+ explicit StepStatsCollector(StepStats* step_stats);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface {
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
- // Save saves nt to the DeviceStats object associated with device.
+ // Saves node statistics to the DeviceStats object associated with device.
// Should be called before Finalize.
- void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats) override;
+ void Save(const string& device, NodeExecStats* node_stats_pb);
+ void Save(const string& device, NodeExecStatsWrapper* node_stats);
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface {
// User shouldn't call Save() methods after Finalize.
void Finalize();
// swaps the content of StepStats* from constructor with 'ss'.
- void FinalizeAndSwap(StepStats* ss);
+ void FinalizeAndSwap(StepStats* step_stats);
private:
+ // TODO(suharshs): Make this configurable if its not possible to find a value
+ // that works for all cases.
+ static const uint64 kMaxCollectedNodes = 1 << 20;
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
- // TODO(suharshs): Make this configurable if its not possible to find a value
- // that works for all cases.
- const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
- std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
- uint64 collectedNodes GUARDED_BY(mu_) = 0;
+ uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h
deleted file mode 100644
index e1b163074f..0000000000
--- a/tensorflow/core/common_runtime/tracing_device.h
+++ /dev/null
@@ -1,60 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
-
-#include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/tracing.h"
-
-namespace tensorflow {
-
-namespace test {
-class Benchmark;
-}
-struct SessionOptions;
-
-// This class implements tracing functionality that is shared by its subclasses
-// (including ThreadPoolDevice and XlaDevice).
-class TracingDevice : public Device {
- public:
- TracingDevice(Env* env, const DeviceAttributes& attributes)
- : Device(env, attributes) {}
-
- void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
- const tracing::TraceCollector* trace_collector =
- tracing::GetTraceCollector();
- if (TF_PREDICT_FALSE(
- (trace_collector &&
- trace_collector->IsEnabled(op_kernel->IsExpensive())) ||
- tracing::GetEventCollector(tracing::EventCategory::kCompute))) {
- const string& op_name = op_kernel->name();
- tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
- op_kernel->IsExpensive());
- tracing::ScopedRegion region(tracing::EventCategory::kCompute, op_name);
- op_kernel->Compute(context);
- } else {
- op_kernel->Compute(context);
- }
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(TracingDevice);
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
deleted file mode 100644
index ae0563a96a..0000000000
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-
-#include <functional>
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/tracking_allocator.h"
-
-namespace tensorflow {
-
-// Subclass VisitableAllocator instead of Allocator when a memory
-// allocator needs to enable some kind of registration/deregistration
-// of memory areas.
-class VisitableAllocator : public Allocator {
- public:
- // Visitor gets called with a pointer to a memory area and its
- // size in bytes.
- typedef std::function<void(void*, size_t)> Visitor;
-
- // Register a visitor guaranteed to be called exactly once on each
- // chunk of memory newly allocated from the underlying device.
- // Typically, chunks will be reused and possibly sub-divided by a
- // pool manager, so the calls will happen only once per process
- // execution, not once per tensor (re)allocation.
- virtual void AddAllocVisitor(Visitor visitor) = 0;
-
- // Register a visitor guaranteed to be called on each chunk of
- // memory returned to the underlying device.
- virtual void AddFreeVisitor(Visitor visitor) = 0;
-};
-
-// Needed for cases when a VisitableAllocator gets wrapped for tracking.
-// Multiple-inheritance is considered acceptable in this case because
-// VisitableAllocator is a pure virtual interface and only TrackingAllocator
-// has default implementation.
-class TrackingVisitableAllocator : public TrackingAllocator,
- public VisitableAllocator {
- public:
- TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
- : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
- ~TrackingVisitableAllocator() override {}
-
- string Name() override { return TrackingAllocator::Name(); }
-
- void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return TrackingAllocator::AllocateRaw(alignment, num_bytes);
- }
-
- void DeallocateRaw(void* ptr) override {
- TrackingAllocator::DeallocateRaw(ptr);
- }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
- }
-
- protected:
- VisitableAllocator* allocator_;
-};
-} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 6c146036ae..3361819e43 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
params.function_library = lib;
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
delete kernel;
}
};
@@ -479,10 +475,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
delete step_container;
});
Executor::Args args;
- {
- mutex_lock l(mu_);
- args.step_id = ++next_id_;
- }
+ args.step_id = step_id;
args.rendezvous = rendezvous;
args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index e7142a4ef9..e36e51d8d5 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -199,7 +199,13 @@ message Example {
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
-//
+// - For sequence modeling, e.g.:
+// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
+// https://github.com/tensorflow/nmt
+// the feature lists represent a sequence of frames.
+// In this scenario, all FeatureLists in a SequenceExample have the same
+// number of Feature messages, so that the ith element in each FeatureList
+// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index ec93b9aad9..016d1a92c1 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -103,6 +103,7 @@ limitations under the License.
#include <iterator>
#include <type_traits>
+#include "absl/base/macros.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@ namespace tensorflow {
namespace internal {
-// DEPRECATED: Use GetFeature instead.
// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,9 +315,9 @@ bool HasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, GetFeatures(example));
}
-// DEPRECATED: use HasFeature instead.
// TODO(gorban): update all clients in a followup CL.
template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
bool ExampleHasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, example);
}
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c57b..84cee5569c 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
@@ -187,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory {
class CPUSubAllocator : public SubAllocator {
public:
explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
- : cpu_allocator_(cpu_allocator) {}
+ : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
void* Alloc(size_t alignment, size_t num_bytes) override {
return cpu_allocator_->AllocateRaw(alignment, num_bytes);
@@ -213,4 +222,22 @@ Allocator* cpu_allocator() {
}
return cpu_alloc;
}
+
+SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {}
+
+void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) {
+ for (const auto& v : alloc_visitors_) {
+ v(ptr, index, num_bytes);
+ }
+}
+
+void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) {
+ // Although we don't guarantee any order of visitor application, strive
+ // to apply free visitors in reverse order of alloc visitors.
+ for (int i = free_visitors_.size() - 1; i >= 0; --i) {
+ free_visitors_[i](ptr, index, num_bytes);
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe137..8c23604625 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Variant;
+
// Attributes for a single allocation call. Different calls to the same
// allocator could potentially have different allocation attributes.
struct AllocationAttributes {
@@ -228,13 +230,9 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
- virtual void RunVariantCtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
- }
+ virtual void RunVariantCtor(Variant* p, size_t n);
- virtual void RunVariantDtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
- }
+ virtual void RunVariantDtor(Variant* p, size_t n);
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
@@ -390,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable);
// full statistics. By default, it's disabled.
void EnableCPUAllocatorFullStats(bool enable);
-// Abstract interface of an object that does the underlying suballoc/free of
-// memory for a higher-level allocator.
+// An object that does the underlying suballoc/free of memory for a higher-level
+// allocator. The expectation is that the higher-level allocator is doing some
+// kind of cache or pool management so that it will call SubAllocator::Alloc and
+// Free relatively infrequently, compared to the number of times its own
+// AllocateRaw and Free methods are called.
class SubAllocator {
public:
+ // Visitor gets called with a pointer to a memory area and its
+ // size in bytes. The index value will be numa_node for a CPU
+ // allocator and GPU id for a GPU allocator.
+ typedef std::function<void(void*, int index, size_t)> Visitor;
+
+ SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors);
+
virtual ~SubAllocator() {}
virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
virtual void Free(void* ptr, size_t num_bytes) = 0;
+
+ protected:
+ // Implementation of Alloc() method must call this on newly allocated
+ // value.
+ void VisitAlloc(void* ptr, int index, size_t num_bytes);
+
+ // Implementation of Free() method must call this on value to be
+ // freed immediately before deallocation.
+ void VisitFree(void* ptr, int index, size_t num_bytes);
+
+ const std::vector<Visitor> alloc_visitors_;
+ const std::vector<Visitor> free_visitors_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282ce84..e907c52ba9 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a3994736c..4ffd732f8e 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c93..af59500aee 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6a90..7a5d942486 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@ class CancellationManager {
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b5..bf7593bc5f 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index 5281c56f04..284dafb886 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -20,7 +20,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
-
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 4e51fba048..697e0604bf 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -46,6 +47,8 @@ class GraphDefBuilder;
class Node;
namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
class DatasetBase;
class SerializationContext;
@@ -291,6 +294,9 @@ class IteratorContext {
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+ // If non-null, identifies the object used for performance modeling.
+ std::shared_ptr<model::Model> model = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -342,6 +348,10 @@ class IteratorContext {
return params_.stats_aggregator_getter;
}
+ std::shared_ptr<model::Model> model() { return params_.model; }
+
+ Params params() { return params_; }
+
private:
Params params_;
};
@@ -376,7 +386,11 @@ class SerializationContext {
// defined below.
class IteratorBase {
public:
- virtual ~IteratorBase() {}
+ virtual ~IteratorBase() {
+ for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+ (*rit)();
+ }
+ }
// Gets the next output from the range that this iterator is traversing.
//
@@ -410,6 +424,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Returns a string that identifies the sequence of iterators leading up to
+ // this iterator.
+ virtual const string& prefix() const = 0;
+
// Performs initialization that needs to happen outside of a constructor to
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -449,6 +467,18 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ private:
+ friend class DatasetBase; // for access to `AddCleanupFunction`
+
+ // Registers a cleanup function to be called upon object destruction.
+ //
+ // Registered functions are invoked in the reserve order of registration.
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+ cleanup_fns_.push_back(std::move(cleanup_fn));
+ }
+
+ std::vector<std::function<void()>> cleanup_fns_;
};
// Represents runtime information needed to construct a dataset.
@@ -498,6 +528,13 @@ class DatasetBase : public core::RefCounted {
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
+ if (ctx->model()) {
+ ctx->model()->AddNode((*iterator)->prefix(), prefix);
+ std::shared_ptr<model::Model> model = ctx->model();
+ const string& prefix = (*iterator)->prefix();
+ (*iterator)->AddCleanupFunction(
+ [model, prefix]() { model->RemoveNode(prefix); });
+ }
return (*iterator)->Initialize(ctx);
}
@@ -524,6 +561,8 @@ class DatasetBase : public core::RefCounted {
IteratorStateWriter* writer) const;
protected:
+ friend class DatasetToGraphOp; // For access to graph related members.
+
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -541,8 +580,6 @@ class DatasetBase : public core::RefCounted {
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
- friend class DatasetToGraphOp; // For access to graph related members.
-
private:
const string name_;
};
@@ -565,7 +602,7 @@ class DatasetBaseIterator : public IteratorBase {
~DatasetBaseIterator() override { params_.dataset->Unref(); }
// The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; }
+ const string& prefix() const override { return params_.prefix; }
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();
@@ -578,7 +615,10 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
+ RecordStart(ctx, true /* stop_output */);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+ RecordStop(ctx, true /* start_output */);
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -605,6 +645,54 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
+ // When performance modeling is enabled, this method adds a constant parameter
+ // to the model node corresponding to this iterator.
+ void AddConstantParameter(IteratorContext* ctx, const string& name,
+ int64 value) {
+ if (ctx->model()) {
+ ctx->model()->AddConstantParameter(prefix(), name, value);
+ }
+ }
+
+ // When performance modeling is enabled, this method adds a tunable parameter
+ // to the model node corresponding to this iterator.
+ //
+ // The performance modeling logic may use `value` to set the value of the
+ // tunable parameter at any point during the lifetime of this iterator. When
+ // it does, it notifies `cond_var`.
+ void AddTunableParameter(IteratorContext* ctx, const string& name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ if (ctx->model()) {
+ ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
+ cond_var);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // this iterator has produced an element.
+ void RecordElement(IteratorContext* ctx) {
+ if (ctx->model()) {
+ ctx->model()->RecordElement(prefix());
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has started work.
+ void RecordStart(IteratorContext* ctx, bool stop_output = false) {
+ if (ctx->model()) {
+ ctx->model()->RecordStart(prefix(), stop_output);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has stopped work.
+ void RecordStop(IteratorContext* ctx, bool start_output = false) {
+ if (ctx->model()) {
+ ctx->model()->RecordStop(prefix(), start_output);
+ }
+ }
+
private:
BaseParams params_;
};
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 794250a2c1..446c31b17f 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -176,9 +177,9 @@ class DeviceBase {
return nullptr;
}
- // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
// This method is provided for backwards compatibility, and will be removed
// in a future release.
+ ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
@@ -214,10 +215,12 @@ class DeviceBase {
// This is overridden by GPU devices to reinitialize the derived
// type returned by MakeGpuDevice.
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
- PerOpGpuDevice* /*device*/,
- DeviceContext* /*dc*/,
- Allocator* /*allocator*/) {}
+ virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ return Status::OK();
+ }
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 26f32677af..a17959a448 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+ tf_shared_lock l(mu_);
+ int index = 0;
+ string name = strings::StrCat(prefix, index);
+ while (function_defs_.find(name) != function_defs_.end()) {
+ ++index;
+ name = strings::StrCat(prefix, index);
+ }
+ return name;
+}
+
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
@@ -1283,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create(
for (const auto& r : ret_def) {
fdef.mutable_ret()->insert({r.first, r.second});
}
+
+ auto* op_def_registry = OpRegistry::Global();
+ // Check if any op is stateful.
+ for (const auto& n : node_def) {
+ const OpDef* op_def = nullptr;
+ auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+ // Lookup can fail if e.g. we are calling a function that was not yet
+ // defined. If it happens, conservatively assume the op is stateful.
+ if (!status.ok() || op_def->is_stateful()) {
+ fdef.mutable_signature()->set_is_stateful(true);
+ }
+ }
return fdef;
}
@@ -1344,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name,
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
+ if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
}
// Returns
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 03296a7761..e01eb7503d 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
+ // Generates new function name with the specified prefix that is unique
+ // across this library.
+ string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index c5a4f661d2..d5c203d276 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -91,6 +91,40 @@ FunctionDef IsZero() {
});
}
+FunctionDef RandomUniform() {
+ const Tensor kZero = test::AsScalar<int64>(0);
+ const Tensor kTen = test::AsScalar<int64>(10);
+
+ return FDH::Define(
+ // Name
+ "RandomUniform",
+ // Args
+ {"x: T"},
+ // Return values
+ {"random_uniform: int64"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ {{{"random_uniform/shape"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/min"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform/max"},
+ "Const",
+ {},
+ {{"value", kTen}, {"dtype", DT_INT64}}},
+ {{"random_uniform"},
+ "RandomUniformInt",
+ {},
+ {{"T", DT_INT64},
+ {"Tout", DT_INT64},
+ {"seed", 87654321},
+ {"seed2", 42}}}});
+}
+
FunctionDef XTimesTwo() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index ad61a76f16..a01743423b 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -84,6 +84,9 @@ FunctionDef NonZero();
// x: T -> bool.
FunctionDef IsZero();
+// x: T -> int64
+FunctionDef RandomUniform();
+
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000000..b0330ec990
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,419 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+#include <memory>
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
+void Model::Node::CollectTunables(
+ std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
+ tf_shared_lock l(mu_);
+ for (auto input : inputs_) {
+ input->CollectTunables(tunables);
+ }
+ switch (type_) {
+ case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_INTERLEAVE_V2:
+ case Type::PARALLEL_MAP: {
+ if (auto* tunable_param =
+ gtl::FindOrNull(tunable_params_, "parallelism")) {
+ tunables->push_back(*tunable_param);
+ }
+ return;
+ }
+ default:
+ return;
+ }
+}
+
+int64 Model::Node::GetParameterValue(const string& name) {
+ if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+ return (*tunable_param)->value;
+ }
+ return constant_params_[name];
+}
+
+int64 Model::Node::ProcessingTimeLocked() {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::MAP_AND_BATCH:
+ case Type::PADDED_BATCH: {
+ int64 batch_size = GetParameterValue("batch_size");
+ return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ return NanosPerElementLocked() +
+ static_cast<int64>(ratio *
+ static_cast<double>(ProcessingTimeForInputs()));
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 processing_time =
+ ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+ return NanosPerElementLocked() +
+ static_cast<double>(processing_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::PARALLEL_MAP:
+ case Type::PREFETCH:
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ return NanosPerElementLocked() + ProcessingTimeForInputs();
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::PADDED_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) /
+ batch_size);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ batch_size * OutputTimeForInputs(input_times);
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ static_cast<int64>(
+ static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ int64 output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ return NanosPerElementLocked() +
+ static_cast<double>(output_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::MAP_AND_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ double parallelism = GetParameterValue("parallelism");
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+ (batch_size * parallelism));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ batch_size * OutputTimeForInputs(input_times));
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = GetParameterValue("parallelism");
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism =
+ std::min(static_cast<int>(GetParameterValue("cycle_length")),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_MAP: {
+ double parallelism =
+ std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 delta = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time =
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ OutputTimeForInputs(input_times);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PREFETCH: {
+ int64 delta = NanosPerElementLocked();
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ return std::max(0LL, NanosPerElementLocked() +
+ OutputTimeForInputs(input_times) -
+ input_times->at(input_times->size() - 2));
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ int64 delta = NanosPerElementLocked();
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
+ }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+ // The name captures the sequence of iterators joined by `::`. We use the full
+ // sequence as the key in the lookup table, but only the last element of the
+ // sequence as the name node.
+ std::vector<string> tokens =
+ str_util::Split(name, ':', str_util::SkipEmpty());
+ // The output name might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_output_name = output_name;
+ if (str_util::EndsWith(output_name, "]")) {
+ sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+ }
+ std::shared_ptr<Node> output;
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
+ if (it != lookup_table_.end()) {
+ output = it->second;
+ }
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
+ if (!output_) {
+ output_ = node;
+ }
+ if (output) {
+ output->add_input(node);
+ }
+ lookup_table_.insert(std::make_pair(name, node));
+}
+
+void Model::AddProcessingTime(const string& name, int64 delta) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
+ }
+}
+
+void Model::AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) {
+ tf_shared_lock l(mu_);
+ auto node = *gtl::FindOrNull(lookup_table_, node_name);
+ DCHECK(node);
+ node->add_tunable_param(parameter_name, value, min, max, cond_var);
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose increase
+// in parallelism decreases the output time the most. This process is repeated
+// until all parameters reach their maximum values or the projected output time
+// is less than or equal to the processing time needed to produce an element
+// divided by CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+ tf_shared_lock lock(mu_);
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
+ }
+ best_tunable->value++;
+ }
+ VLOG(2) << "Number of knobs: " << tunables.size();
+ for (auto& tunable : tunables) {
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ tunable->value_ptr->store(tunable->value);
+ if (tunable->cond_var) {
+ tunable->cond_var->notify_all();
+ }
+ }
+}
+
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
+ }
+}
+
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
+}
+
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
+ }
+ lookup_table_.erase(name);
+}
+
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ output_->CollectTunables(&tunables);
+ return tunables;
+}
+
+int64 Model::OutputTime() {
+ std::vector<int64> input_times(1, 0);
+ return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..26402f5cd3
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,404 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread> // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// Abstract representation of a TensorFlow input pipeline that can be used
+// for collecting runtime information and optimizing performance. It collects
+// runtime information about execution of the input pipeline that is used to
+// create a performance model, which is in turn used to identify optimal values
+// of tunable parameters.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boiler plate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
+class Model {
+ public:
+ Model() = default;
+
+ // Adds a constant parameter for the given node.
+ void AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value)
+ LOCKS_EXCLUDED(mu_);
+
+ // Adds a node with the given name and given output (identified by name).
+ void AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Increments the processing time for the given node..
+ void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
+
+ // Adds a tunable parameter for the given node.
+ void AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
+
+ // Runs optimization.
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
+
+ // Records that a node has produced an element.
+ void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has started work. If `stop_output` is set, it
+ // also records that the output of the given node has stopped work.
+ void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has stopped work. If `stop_output` is set, it
+ // also records that the output of the given node has started work.
+ void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+ // Removes the given node.
+ void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ private:
+ // Abstract representation of a TensorFlow input pipeline node. It collects
+ // information about inputs to this node, processing time spent executing the
+ // node logic, number of elements produced by the node, various other
+ // information (e.g. batch size or execution parallelism).
+ //
+ // Developers of tf.data transformations are not expected to interact with
+ // this class directly. Boiler plate code for creating the abstract
+ // representation of the input pipeline and collecting common information has
+ // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`
+ // respectively.
+ //
+ // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+ // transformation-specific information collection. The `SetMetadata` wrapper
+ // can be used to pass arbitrary metadata to the modeling framework, while the
+ // `StartWork` and `StopWork` wrappers should be used to correctly account for
+ // processing time of multi-threaded transformation that yield the CPU; such
+ // transformations should invoke `StartWork()` when a transformation thread
+ // starts executing (e.g. when created or woken up) and `StopWork()` when a
+ // transformation thread stops executing (e.g. when returning or waiting).
+ //
+ // TODO(jsimsa): Create an API to capture the abstract semantics of each
+ // tf.data transformation and replace switch-case blocks with inheritance.
+ class Node {
+ public:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(std::atomic<int64>* value, int64 min, int64 max,
+ condition_variable* cond_var)
+ : value(*value),
+ min(min),
+ max(max),
+ value_ptr(value),
+ cond_var(cond_var) {}
+
+ // Identifies the model value of the parameter. This can be different from
+ // the actual value (e.g. during optimization search).
+ int64 value;
+
+ // Identifies the minimum value of the parameter.
+ int64 min;
+
+ // Identifies the maximum value of the parameter.
+ int64 max;
+
+ // Points to the actual value of the parameter. Not owned.
+ std::atomic<int64>* value_ptr;
+
+ // If non-null, this condition variable is notified when the model updates
+ // the actual value of the parameter (via `value_ptr`). Not owned.
+ condition_variable* cond_var;
+ };
+
+ Node(int64 id, const string& name, std::shared_ptr<Node> output)
+ : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name, std::atomic<int64>* value,
+ int64 min, int64 max, condition_variable* cond_var)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(value, min, max, cond_var);
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Records that the node produced an element.
+ void record_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Records that a node thread has started executing.
+ void record_start() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped executing.
+ void record_stop() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ std::thread::id tid = std::this_thread::get_id();
+ auto start_time = gtl::FindOrNull(work_start_, tid);
+ DCHECK(start_time)
+ << "Encountered a stop event that was not preceded by a start event.";
+ if (start_time) {
+ processing_time_ += Env::Default()->NowNanos() - *start_time;
+ work_start_.erase(tid);
+ }
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Set the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in
+ // this node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ private:
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return (int64)((double)processing_time_ / (double)num_elements_);
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTime();
+ }
+ return sum;
+ }
+
+ Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ const string name_;
+ const Type type_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ };
+
+ std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+ SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
+
+ // Used for coordination between different input pipeline threads. Exclusive
+ // access is required only when adding or removing nodes. Concurrent access to
+ // existing nodes is protected by a node mutex.
+ mutex mu_;
+ int64 id_counter_ GUARDED_BY(mu_) = 1;
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index bacc1d72c4..43ac1d0ada 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,12 +405,18 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs) {
+ DataTypeVector outputs;
+ TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
+ *num_outputs = outputs.size();
+ return Status::OK();
+}
+
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (node_def.op() != op_def.name()) {
return errors::InvalidArgument("NodeDef op '", node_def.op(),
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 499034cab2..187bfa2c88 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
@@ -261,6 +265,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);
+// Computes the number of outputs for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs);
// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index c694e10193..3e34bf0418 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
// OpKernel ------------------------------------------------------------------
-// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : OpKernel(context,
- std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+ : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
@@ -266,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
params_->ensure_eigen_gpu_device();
if (params_->eigen_gpu_device != nullptr) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ Status s = params_->device->ReinitializeGpuDevice(
+ this, params_->eigen_gpu_device, params_->op_device_context,
+ eigen_gpu_allocator);
+ if (!s.ok()) {
+ SetStatus(s);
+ }
}
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
@@ -525,10 +527,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input(
return nullptr;
}
}
- // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
- // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
- // general cleanup of ownership in this code.
- std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+ auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index e752599de1..4bbd6c3d7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -372,18 +372,37 @@ class OpKernelConstruction {
template <typename ListType, typename ElementType>
class OpArgIterator {
public:
- typedef OpArgIterator<ListType, ElementType> ME;
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = ElementType;
+ using pointer = ElementType*;
+ using reference = ElementType&;
+ using difference_type = ptrdiff_t;
+
OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
- bool operator==(const ME& rhs) {
+
+ bool operator==(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ == rhs.i_;
}
- bool operator!=(const ME& rhs) {
+
+ bool operator!=(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ != rhs.i_;
}
- void operator++() { ++i_; }
- ElementType& operator*() { return (*list_)[i_]; }
+
+ OpArgIterator operator++() { // prefix ++it
+ ++i_;
+ return *this;
+ }
+
+ OpArgIterator operator++(int) { // postfix it++
+ OpArgIterator old_value = *this;
+ ++i_;
+ return old_value;
+ }
+
+ reference operator*() { return (*list_)[i_]; }
+ pointer operator->() { return &(*list_)[i_]; }
private:
const ListType* const list_;
@@ -394,7 +413,7 @@ class OpArgIterator {
// that are passed to the op as a single named argument.
class OpInputList {
public:
- typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext* ctx, int start, int stop)
: ctx_(ctx), start_(start), stop_(stop) {}
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7747..75ed4a4eaf 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) {
delete item;
}
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op) {
+ // OpSegment should not own kernel if the node is stateless, or a function.
+ return lib->IsStateful(node_op) &&
+ lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a2554f..37d939ea2b 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@ class OpSegment {
Status FindOrCreate(const string& session_handle, const string& node_name,
OpKernel** kernel, CreateKernelFn create_fn);
+ // Returns true if OpSegment should own the kernel.
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op);
+
private:
// op name -> OpKernel
typedef std::unordered_map<string, OpKernel*> KernelMap;
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 0a19861efd..ebdaaec153 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const {
"]");
}
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index f8a587c9b5..d58deaa3fc 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -79,7 +79,7 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
// Returns memory used by this resource.
- virtual int64 MemoryUsed() const { return 0; };
+ virtual int64 MemoryUsed() const { return 0; }
};
// Container used for per-step resources.
@@ -234,7 +234,7 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name);
// Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
@@ -348,6 +348,8 @@ class ResourceHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
private:
string container_;
string name_;
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 516afa517d..3df677675e 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
return ret;
}
+Tensor Tensor::SubSlice(int64 index) const {
+ CHECK_GE(dims(), 2); // Crash ok.
+ CHECK_LE(0, index); // Crash ok.
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(index, dim0_size); // Crash ok.
+ Tensor ret;
+ ret.shape_ = shape_;
+ ret.shape_.RemoveDim(0);
+ ret.set_dtype(dtype());
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = index * elems_per_dim0;
+ const int64 num_elems = elems_per_dim0;
+ if (buf_) {
+ DataType dt = dtype();
+ CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
bool Tensor::FromProto(const TensorProto& proto) {
return FromProto(cpu_allocator(), proto);
}
@@ -948,9 +970,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+ // Loop every element of one dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+ // As for each element, print the sub-dim.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+ // As for each element, print the sub-dim.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1045,26 @@ string SummarizeArray(int64 limit, int64 num_elts,
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
- if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1073,54 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1140,9 @@ string Tensor::SummarizeValue(int64 max_entries) const {
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5da3..8a0c70fef2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@ namespace tensorflow {
class AllocationDescription;
class Allocator;
class OpKernelContext;
+class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
-class VariantTensorData;
+
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
@@ -153,7 +154,7 @@ class Tensor {
/// Returns the estimated memory usage of this tensor.
size_t TotalBytes() const;
- // Returns the size of sallocated memory for this tensor.
+ // Returns the size of allocated memory for this tensor.
size_t AllocatedBytes() const;
/// Returns true iff this tensor is aligned.
@@ -199,10 +200,29 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
+ /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
+ /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
+ ///
/// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
+ /// \brief Select a subslice from this tensor along the 1st dimension.
+ ///
+ /// When fed with an N-dimensional tensor, this method returns a tensor with
+ /// N-1 dimensions, where the returned tensor is a subslice of the input
+ /// tensor along the first dimension. The N-1 dimensions of the returned
+ /// tensor are the last N-1 dimensions of the input tensor.
+ ///
+ /// NOTE: The returned tensor may not satisfy the same alignment
+ /// requirement as this tensor depending on the shape. The caller
+ /// must check the returned tensor's alignment before calling certain
+ /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
+ ///
+ /// REQUIRES: `dims()` >= 2
+ /// REQUIRES: `0 <= dim0_start < dim_size(0)`
+ Tensor SubSlice(int64 index) const;
+
/// \brief Parse `other` and construct the tensor.
/// Returns `true` iff the parsing succeeds. If the parsing fails,
@@ -429,7 +449,7 @@ class Tensor {
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c196..0bfa53e6c5 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -1227,6 +1228,42 @@ TEST(Tensor, Slice_Basic) {
}
}
+TEST(Tensor, SubSlice_Basic) {
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 36}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.SubSlice(i).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple sub-slice along dim0.
+ Tensor y = x.SubSlice(5);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 2>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 36; ++k) {
+ EXPECT_EQ(ty(j, k), 5.0);
+ EXPECT_EQ(&tx(5, j, k), &ty(j, k));
+ }
+ }
+ }
+ {
+ // Test unaligned access via a SubSlice.
+ Tensor x(DT_FLOAT, TensorShape({30, 5}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned subslice.
+ Tensor y = x.SubSlice(1);
+#if EIGEN_MAX_ALIGN_BYTES > 0
+ EXPECT_FALSE(y.IsAligned());
+#endif
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
template <typename T>
Tensor MkTensor(DataType dt, const TensorShape& shape,
std::vector<T> init_values) {
@@ -1294,6 +1331,63 @@ TEST(SummarizeValue, STRING) {
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 4bda8f9eb8..a7cf600bab 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 15b1add2c1..2e96b05787 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@ limitations under the License.
namespace tensorflow {
+class Variant;
+
// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a507804b0..d43e3c72ec 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@ limitations under the License.
namespace tensorflow {
-bool Variant::TryDecode(Variant* out) const {
- const VariantTensorDataProto* p = get<VariantTensorDataProto>();
- if (p == nullptr) return false;
- VariantTensorData data(*p);
- return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+ if (!is_empty()) {
+ return value_->Decode(std::move(data));
+ }
+ return true;
}
template <>
@@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) {
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data) {
- data->FromProto(value);
+ data->FromConstProto(value);
}
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value) {
- data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+ data->ToProto(value);
return true;
}
@@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
}
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
- return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(*buf);
}
void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
if (variant_array[i].is_empty()) {
variant_array[i] = VariantTensorDataProto();
}
+ // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
+ // zero-copy operation that keeps a reference to the data in d?
string str(d->Data(sizes[i]), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
+ if (!variant_array[i].Decode(std::move(str))) return false;
if (!DecodeUnaryVariant(&variant_array[i])) {
LOG(ERROR) << "Could not decode variant with type_name: \""
<< variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index 52732801a0..10eabbc85f 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -23,7 +23,6 @@ limitations under the License.
#include <unordered_map>
#include <utility>
-#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"
@@ -38,17 +37,19 @@ string TypeNameVariant(const T& value);
template <typename T>
string DebugStringVariant(const T& value);
+// Allows for specializations of Variant Decoding. `data` may be modified in
+// the process of decoding to `value`.
template <typename T>
-void EncodeVariant(const T& value, VariantTensorData* data);
+bool DecodeVariant(VariantTensorData* data, T* value);
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
+bool DecodeVariant(string* buf, T* value);
template <typename T>
-void EncodeVariant(const T& value, string* buf);
+void EncodeVariant(const T& value, VariantTensorData* data);
template <typename T>
-bool DecodeVariant(const string& buf, T* value);
+void EncodeVariant(const T& value, string* buf);
// This is an implementation of a type-erased container that can store an
// object of any type. The implementation is very similar to std::any, but has
@@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value);
//
// string TypeName() const;
// void Encode(VariantTensorData* data) const;
-// void Decode(const VariantTensorData& data);
+// void Decode(VariantTensorData data);
//
// Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods.
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value);
// x.Encode(&serialized_f);
//
// Variant y = Foo(); // default constructed Foo.
-// y.Decode(&serialized_f);
+// y.Decode(std::move(serialized_f));
// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
//
//
@@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value);
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
// y_type_unknown.TypeId());
-// // Decode and get y_type_unknown; compare to value in x.
-// Foo f_decoded;
-// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-// EXPECT_EQ(f_decoded, f);
//
class Variant {
public:
@@ -241,12 +238,7 @@ class Variant {
}
// Deserialize `data` and update the stored object.
- bool Decode(const VariantTensorData& data) {
- if (!is_empty()) {
- return value_->Decode(data);
- }
- return true;
- }
+ bool Decode(VariantTensorData data);
// Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const {
@@ -254,31 +246,13 @@ class Variant {
value_->Encode(buf);
}
}
- bool Decode(const string& buf) {
+ bool Decode(string buf) {
if (!is_empty()) {
- return value_->Decode(buf);
+ return value_->Decode(std::move(buf));
}
return true;
}
- template <typename T>
- bool MaybeDecodeAndCopy(T* out) const {
- const T* ret = get<T>();
- if (ret != nullptr) {
- *out = std::move(*ret);
- return true;
- };
- Variant decoded = T();
- if (!TryDecode(&decoded)) return false;
- T* decoded_ret = decoded.get<T>();
- CHECK_NOTNULL(decoded_ret);
- *out = std::move(*decoded_ret);
- return true;
- }
-
- private:
- bool TryDecode(Variant* out) const;
-
private:
struct in_place_t {};
static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@ class Variant {
virtual string TypeName() const = 0;
virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0;
- virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual bool Decode(VariantTensorData data) = 0;
virtual void Encode(string* buf) const = 0;
- virtual bool Decode(const string& data) = 0;
+ virtual bool Decode(string data) = 0;
};
template <typename T>
@@ -325,15 +299,13 @@ class Variant {
EncodeVariant(value, data);
}
- bool Decode(const VariantTensorData& data) override {
- return DecodeVariant(data, &value);
+ bool Decode(VariantTensorData data) override {
+ return DecodeVariant(&data, &value);
}
void Encode(string* buf) const override { EncodeVariant(value, buf); }
- bool Decode(const string& buf) override {
- return DecodeVariant(buf, &value);
- }
+ bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
T value;
};
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index f155aa4892..5e08e5a7a6 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value,
// Specialization for POD type
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for tensorflow::Tensor
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for protobuf
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
@@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for other types
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
- return value->Decode(data);
+ return value->Decode(std::move(data));
}
template <typename C, typename = void>
@@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) {
}
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
- return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+ return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
@@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) {
}
template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
- if (!data.ParseFromString(buf)) return false;
- if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ if (!data.ParseFromString(*buf)) return false;
+ if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+ return false;
+ }
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
+
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects in to the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd559..daa744e877 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+ StoredTensorValue::CopyCPUToGPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
- "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+ StoredTensorValue::CopyGPUToCPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+ StoredTensorValue::CopyGPUToGPU);
REGISTER_OP("CreateTestVariant")
.Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1aee..ef5b240aea 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
- StringPiece type_name) {
- auto found = shape_fns.find(type_name);
+ const TypeIndex& type_index) {
+ auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
- VariantShapeFn* existing = GetShapeFn(type_name);
+ VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantShapeFn for type_name: " << type_name
- << " already registered";
- shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
- GetPersistentStringPiece(type_name), shape_fn));
+ << "Unary VariantShapeFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+ shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
- "No unary variant shape function found for Variant type_name: ",
- v.TypeName());
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}
@@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) {
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double);
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
UnaryVariantOpRegistry::GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name) {
- auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+ auto found = device_copy_fns.find(std::make_pair(direction, type_index));
if (found == device_copy_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
- AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+ AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
CHECK_EQ(existing, nullptr)
<< "UnaryVariantDeviceCopy for direction: " << direction
- << " and type_name: " << type_name << " already registered";
+ << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+ << " already registered";
device_copy_fns.insert(
- std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn>(
- std::make_pair(direction, GetPersistentStringPiece(type_name)),
- device_copy_fn));
+ std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+ device_copy_fn));
}
Status VariantDeviceCopy(
@@ -170,35 +167,34 @@ Status VariantDeviceCopy(
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
- from.TypeName());
+ from.TypeId());
if (device_copy_fn == nullptr) {
return errors::Internal(
"No unary variant device copy function found for direction: ",
- direction, " and Variant type_name: ", from.TypeName());
+ direction, " and Variant type_index: ",
+ port::MaybeAbiDemangle(from.TypeId().name()));
}
return (*device_copy_fn)(from, to, copy_fn);
}
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
- VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find({op, device, type_name});
+ VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+ auto found = unary_op_fns.find({op, device, type_index});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
- VariantUnaryOp op, const string& device, const string& type_name,
+ VariantUnaryOp op, const string& device, const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
- VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+ VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantUnaryOpFn for type_name: " << type_name
+ << "Unary VariantUnaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- unary_op_fn));
+ {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
}
namespace {
@@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
- DEVICE_CPU, T, TF_STR(T), \
+ DEVICE_CPU, T, \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name) {
- auto found = binary_op_fns.find({op, device, type_name});
+ const TypeIndex& type_index) {
+ auto found = binary_op_fns.find({op, device, type_index});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
- VariantBinaryOp op, const string& device, const string& type_name,
+ VariantBinaryOp op, const string& device, const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
- VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+ VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantBinaryOpFn for type_name: " << type_name
+ << "Unary VariantBinaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- add_fn));
+ {op, GetPersistentStringPiece(device), type_index}, add_fn));
}
namespace {
@@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
- T, TF_STR(T), \
- AddVariantPrimitiveType<T>);
+ T, AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index e6a2665a56..7eb37e859f 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -22,10 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
@@ -90,10 +94,11 @@ class UnaryVariantOpRegistry {
AsyncVariantDeviceCopyFn;
// Add a shape lookup function to the registry.
- void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+ void RegisterShapeFn(const TypeIndex& type_index,
+ const VariantShapeFn& shape_fn);
- // Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(StringPiece type_name);
+ // Returns nullptr if no shape function was found for the given TypeIndex.
+ VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@ class UnaryVariantOpRegistry {
// Add a copy-to-GPU function to the registry.
void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
- const string& type_name,
+ const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn);
// Returns nullptr if no copy function was found for the given
// TypeName and direction.
AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name);
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@ class UnaryVariantOpRegistry {
static std::unordered_set<string>* PersistentStringStorage();
private:
- std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
- std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
- decode_fns;
+ struct TypeIndexHash {
+ std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+ };
+
+ gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+ gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
// Map std::pair<Direction, type_name> to function.
struct PairHash {
template <typename Direction>
- std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+ std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
- ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+ ret = Hash64Combine(ret, std::get<1>(x).hash_code());
return ret;
}
- StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn, PairHash>
+ gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn, PairHash>
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry {
// and references therein
template <typename Op>
struct FuncTuple {
- FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
- : op_type_(op), device_(dev), typename_(tname){};
+ FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+ : op_type_(op), device_(dev), type_index_(type_index) {}
Op op_type_;
- StringPiece device_, typename_;
+ StringPiece device_;
+ TypeIndex type_index_;
};
// friend declaration for operator==
// needed for clang
@@ -184,11 +192,11 @@ class UnaryVariantOpRegistry {
struct TupleHash {
template <typename Op>
std::size_t operator()(
- const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ const std::tuple<Op, StringPiece, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
- ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+ ret = Hash64Combine(ret, std::get<2>(x).hash_code());
return ret;
}
@@ -197,14 +205,14 @@ class UnaryVariantOpRegistry {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(x.op_type_);
ret = Hash64Combine(ret, sp_hasher_(x.device_));
- ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ ret = Hash64Combine(ret, x.type_index_.hash_code());
return ret;
}
StringPieceHasher sp_hasher_;
};
- std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@ template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
- (lhs.typename_ == rhs.typename_);
+ (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
- UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
- if (a.TypeName() != b.TypeName()) {
+ if (a.TypeId() != b.TypeId()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
- "type names: '",
+ "type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
- UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
- UnaryVariantShapeRegistration(const string& type_name,
+ UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
- type_name,
- [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+ type_index,
+ [type_index_name, shape_fn](const Variant& v,
+ TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
- "VariantShapeFn: Could not access object, type_name: ",
- type_name);
+ "VariantShapeFn: Could not access object, type_index: ",
+ type_index_name);
}
return shape_fn(*t, s);
});
@@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration {
return false;
}
Variant decoded = T();
- VariantTensorData data(*t);
- if (!decoded.Decode(data)) {
+ VariantTensorData data(std::move(*t));
+ if (!decoded.Decode(std::move(data))) {
return false;
}
- *v = std::move(decoded);
+ std::swap(decoded, *v);
return true;
});
}
@@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration {
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const LocalVariantDeviceCopyFn& device_copy_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
- direction, type_name,
- [type_name, device_copy_fn](
+ direction, type_index,
+ [type_index_name, device_copy_fn](
const Variant& from, Variant* to,
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration {
*to = T();
if (from.get<T>() == nullptr) {
return errors::Internal(
- "VariantCopyToGPUFn: Could not access object, type_name: ",
- type_name);
+ "VariantCopyToGPUFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *from.get<T>();
T* t_out = to->get<T>();
@@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration {
public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantUnaryOpFn& unary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
- op, device, type_name,
- [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
- Variant* v_out) -> Status {
+ op, device, type_index,
+ [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) -> Status {
DCHECK_NE(v_out, nullptr);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantUnaryOpFn: Could not access object, type_name: ",
- type_name);
+ "VariantUnaryOpFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration {
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantBinaryOpFn& binary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
- op, device, type_name,
- [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
- const Variant& b, Variant* out) -> Status {
+ op, device, type_index,
+ [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+ const Variant& b,
+ Variant* out) -> Status {
DCHECK_NE(out, nullptr);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'a', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+ type_index_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'b', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+ type_index_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration {
// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
- shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
- shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+ shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
- register_unary_variant_op_shape_registration_fn_##ctr(type_name, \
+ register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)
// Register a unary decode variant function for the given type.
@@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration {
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- T, direction, type_name, device_copy_fn) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
+ device_copy_fn) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- ctr, T, direction, type_name, device_copy_fn) \
+ ctr, T, direction, type_index, device_copy_fn) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn)
+ ctr, T, direction, type_index, device_copy_fn)
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantDeviceCopyRegistration<T> \
- register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
- device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+ ctr, T, direction, type_index, device_copy_fn) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantDeviceCopyRegistration<T> \
+ register_unary_variant_op_device_copy_fn_##ctr( \
+ direction, type_index, device_copy_fn)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
- unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
- unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_index, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+ type_index, unary_op_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, unary_op_function) \
+ ctr, op, device, T, type_index, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
- binary_op_function) \
- REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, binary_op_function) \
+ ctr, op, device, T, type_index, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function)
+ ctr, op, device, T, type_index, binary_op_function)
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantBinaryOpRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_index, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
binary_op_function)
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc114..3e67e4a864 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@ namespace tensorflow {
VariantTensorData::VariantTensorData() {}
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
- FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+ FromProto(std::move(proto));
}
VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
}
}
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+ // TODO(ebrevdo): Do this lazily.
+ set_type_name(proto.type_name());
+ set_metadata(proto.metadata());
+ for (const auto& tensor : proto.tensors()) {
+ Tensor tmp;
+ if (!tmp.FromProto(tensor)) return false;
+ tensors_.push_back(tmp);
+ }
+ return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
set_type_name(proto.type_name());
set_metadata(proto.metadata());
for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) {
return proto.SerializeToString(buf);
}
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
VariantTensorDataProto proto;
const bool status = proto.ParseFromString(s);
- if (status) FromProto(proto);
+ if (status) FromProto(std::move(proto));
return status;
}
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 7500e77d43..8a240ee1e3 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -19,13 +19,13 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class VariantTensorDataProto;
-class Tensor;
// The serialization format for Variant objects. Objects with references to
// other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@ class Tensor;
class VariantTensorData {
public:
VariantTensorData();
- VariantTensorData(const VariantTensorDataProto& proto);
+ VariantTensorData(VariantTensorDataProto proto);
~VariantTensorData();
// Name of the type of objects being serialized.
@@ -68,12 +68,14 @@ class VariantTensorData {
// Conversion to and from VariantTensorDataProto
void ToProto(VariantTensorDataProto* proto) const;
- bool FromProto(const VariantTensorDataProto& proto);
+ // This allows optimizations via std::move.
+ bool FromProto(VariantTensorDataProto proto);
+ bool FromConstProto(const VariantTensorDataProto& proto);
// Serialization via VariantTensorDataProto
string SerializeAsString() const;
bool SerializeToString(string* buf);
- bool ParseFromString(const string& s);
+ bool ParseFromString(string s);
string DebugString() const;
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47d15..08d09de7b8 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) {
struct TensorList {
void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
- bool Decode(const VariantTensorData& data) {
- vec = data.tensors_;
+ bool Decode(VariantTensorData data) {
+ vec = std::move(data.tensors_);
return true;
}
@@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) {
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(serialized);
+ y.Decode(std::move(serialized));
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) {
EXPECT_EQ(y_unknown.DebugString(),
strings::StrCat(
"Variant<type: TensorList value: ", data.DebugString(), ">"));
-
- TensorList unknown_decoded_vec;
- EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
- }
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
- }
}
TEST(VariantTest, VariantArray) {
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index ee10194142..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
- LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+ LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
<< " NODES IN A CYCLE";
for (int64 i = 0; i < node_defs_.size(); i++) {
if (pending_count_[i] != 0) {
LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
- << "WITH PENDING COUNT = " << pending_count_[i];
+ << " WITH PENDING COUNT = " << pending_count_[i];
}
}
return errors::InvalidArgument(node_defs_.size() - processed,
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index 73142ebde7..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 2e644fe987..f5b0105862 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index bd0284d43a..b00196f587 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -32,7 +32,7 @@ namespace test {
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 7171ae059b..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->set_shape_optimization(RewriterConfig::OFF);
rewriter_config->set_remapping(RewriterConfig::OFF);
+ rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index b97603c890..e4f6bf7c86 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -93,13 +93,13 @@ Status SingleMachine::Provision() {
strings::StrCat("Not able to parse GPU device name: ", dev.name()));
}
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
return errors::Unavailable("Unknown TF GPU device with id ",
tf_gpu_id.value(), ": ", s.ToString());
}
- attr = GetLocalGPUInfo(cuda_gpu_id);
+ attr = GetLocalGPUInfo(platform_gpu_id);
} else if (dev.device_type().find("XLA") == string::npos) {
// Filter out the fake XLA devices to avoid double counting the actual
// hardware resources that are available.
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index a7519725a5..567e7c075e 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -70,13 +70,14 @@ DeviceProperties GetLocalCPUInfo() {
return device;
}
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) {
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
DeviceProperties device;
device.set_type("GPU");
#if GOOGLE_CUDA
cudaDeviceProp properties;
- cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+ cudaError_t error =
+ cudaGetDeviceProperties(&properties, platform_gpu_id.value());
if (error != cudaSuccess) {
device.set_type("UNKNOWN");
LOG(ERROR) << "Failed to get device properties, error code: " << error;
@@ -122,15 +123,15 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
} else if (device.type == "GPU") {
if (device.has_id) {
TfGpuId tf_gpu_id(device.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
LOG(ERROR) << s;
return unknown;
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else {
- return GetLocalGPUInfo(CudaGpuId(0));
+ return GetLocalGPUInfo(PlatformGpuId(0));
}
}
return unknown;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index ca15c48006..f0a342b728 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -28,7 +28,7 @@ DeviceProperties GetLocalCPUInfo();
// Returns the DeviceProperties for the specified GPU attached to the server on
// which grappler is running.
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id);
// Returns the DeviceProperties of the specified device
DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 74218adbac..3863d62980 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -31,22 +31,22 @@ TEST(UtilsTest, GetLocalGPUInfo) {
LOG(INFO) << "CUDA is enabled.";
DeviceProperties properties;
- // Invalid CUDA GPU ID.
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ // Invalid platform GPU ID.
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("UNKNOWN", properties.type());
- // Succeed when a valid CUDA GPU id was inserted.
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ // Succeed when a valid platform GPU id was inserted.
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("NVIDIA", properties.vendor());
#else
LOG(INFO) << "CUDA is not enabled.";
DeviceProperties properties;
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("GPU", properties.type());
#endif
}
@@ -74,20 +74,20 @@ TEST(UtilsTest, GetDeviceInfo) {
EXPECT_EQ("NVIDIA", properties.vendor());
#endif
- // TF to CUDA GPU id mapping entry doesn't exist.
+ // TF to platform GPU id mapping entry doesn't exist.
device.has_id = true;
device.id = 0;
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
#if GOOGLE_CUDA
- // Invalid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+ // Invalid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100));
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
- // Valid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+ // Valid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0));
device.id = 1;
properties = GetDeviceInfo(device);
EXPECT_EQ("GPU", properties.type());
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d24e7e8ee4..56c8339d57 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsEnqueue(const NodeDef& n) {
- return (n.op().find("Enqueue") != std::string::npos &&
- n.op().find("EnqueueMany") == std::string::npos);
+ return (n.op().find("Enqueue") != string::npos &&
+ n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
- return (n.op().find("Dequeue") != std::string::npos &&
- n.op().find("DequeueMany") == std::string::npos);
+ return (n.op().find("Dequeue") != string::npos &&
+ n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
@@ -345,6 +345,56 @@ void VerboseLogUnknownDimensionSources(
}
}
+bool IsShapeFullyDefinedIntegerVectorOrScalar(
+ InferenceContext* ic, const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape, const DataType& dtype) {
+ if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
+ !ic->FullyDefined(tensor_as_shape) ||
+ (dtype != DT_INT32 && dtype != DT_INT64)) {
+ return false;
+ }
+ return true;
+}
+
+// Returned tensor's shape is like `shape`, and its values and dtype are from
+// `tensor_as_shape` and `dtype`.
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(dtype);
+ auto* shape_proto = tensor_proto.mutable_tensor_shape();
+ if (ic->Rank(shape) == 1) {
+ shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
+ }
+ // For a scalar tensor, tensor_shape field will be left empty; no dim.
+ for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
+ int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
+ if (dtype == DT_INT32) {
+ tensor_proto.add_int_val(value);
+ } else {
+ tensor_proto.add_int64_val(value);
+ }
+ }
+ return tensor_proto;
+}
+
+// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
+// and dtype = `dtype`.
+NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ NodeDef const_node;
+ const_node.set_name("const_from_shape");
+ const_node.set_op("Const");
+ auto* attr = const_node.mutable_attr();
+ (*attr)["dtype"].set_type(dtype);
+ auto* tensor = (*attr)["value"].mutable_tensor();
+ *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype);
+ return const_node;
+}
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -494,14 +544,26 @@ class SymbolicShapeRefiner {
// Replace input Placeholders with Consts, if values are known. Note that
// we don't check exceptions here as it's done in the above loop.
+ auto* ctx = GetNodeContext(function_node);
+ auto* ic = ctx->inference_context.get();
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i);
const string& node_name = NodeName(input);
NodeDef* input_node = graph_.GetNode(node_name);
- // TODO(dyoon): also use Const when output_tensors_as_shape is available.
if (IsConstant(*input_node)) {
TF_CHECK_OK(
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ // We have fully defined input_tensors_as_shapes for this input; use it
+ // as a const input to the function node.
+ NodeDef const_input_node = MakeConstNodeDefFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
+ TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
+ &grappler_function_item));
}
}
@@ -510,8 +572,8 @@ class SymbolicShapeRefiner {
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(function_node);
int output = 0;
+ ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
// TODO(jmdecker): Handle case of multiple output tensors
@@ -544,6 +606,14 @@ class SymbolicShapeRefiner {
ShapeHandle out;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
ic->set_output(output, out);
+ if (outprop.has_value()) {
+ // Forward tensor value to output_tensors_as_shape.
+ Tensor tensor;
+ if (tensor.FromProto(outprop.value())) {
+ MaybeSetTensorValueToShape(ic, tensor,
+ &ctx->output_tensors_as_shapes[output]);
+ }
+ }
output++;
}
@@ -586,21 +656,9 @@ class SymbolicShapeRefiner {
if (const_values[dst_input].FromProto(
input->attr().at("value").tensor())) {
input_tensors[dst_input] = &const_values[dst_input];
- // Integer tensors of rank one can also be interpreted as a shape
- // provided all their values are >= -1.
- if (const_values[dst_input].dims() == 1 &&
- (const_values[dst_input].dtype() == DT_INT32 ||
- const_values[dst_input].dtype() == DT_INT64)) {
- ShapeHandle tensor_shape = inference_context->Vector(
- const_values[dst_input].NumElements());
- ShapeHandle shp;
- if (inference_context
- ->MakeShapeFromTensor(input_tensors[dst_input],
- tensor_shape, &shp)
- .ok()) {
- input_tensors_as_shapes[dst_input] = shp;
- }
- }
+ MaybeSetTensorValueToShape(inference_context,
+ const_values[dst_input],
+ &input_tensors_as_shapes[dst_input]);
}
} else if (IsRank(*input)) {
if (c->inference_context->RankKnown(c->inference_context->input(0))) {
@@ -968,13 +1026,25 @@ class SymbolicShapeRefiner {
: t->scalar<int64>()();
dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
} else {
- dims.push_back(ic->UnknownDim());
+ // Don't have tensor value, but use input_tensors_as_shapes, if
+ // possible.
+ const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
+ if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
+ ic->ValueKnown(ic->Dim(shape_handle, 0))) {
+ dims.push_back(ic->Dim(shape_handle, 0));
+ } else {
+ dims.push_back(ic->UnknownDim());
+ }
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
}
+ } else if (IsIdentity(node)) {
+ // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+ c->output_tensors_as_shapes.resize(1);
+ c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
} else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
@@ -1079,6 +1149,46 @@ class SymbolicShapeRefiner {
}
private:
+ bool IsIntegerVector(const Tensor& tensor) {
+ if (tensor.dims() == 1 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
+ return true;
+ }
+ return false;
+ }
+
+ bool IsIntegerScalar(const Tensor& tensor) {
+ if (tensor.dims() == 0 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
+ tensor.NumElements() == 1) {
+ return true;
+ }
+ return false;
+ }
+
+ void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
+ ShapeHandle* tensors_as_shapes) {
+ // Integer tensors of rank one can also be interpreted as a shape
+ // provided all their values are >= -1.
+ if (IsIntegerVector(tensor)) {
+ ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
+ ShapeHandle shp;
+ // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1).
+ if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) {
+ *tensors_as_shapes = shp;
+ }
+ } else if (IsIntegerScalar(tensor)) {
+ // Scalar constant.
+ int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
+ : tensor.flat<int64>()(0);
+ // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
+ // It's a limitation as we use ShapeHandle as a means to pass values.
+ if (value >= -1) {
+ *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
+ }
+ }
+ }
+
const GraphView& graph_;
int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
@@ -1554,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
continue;
}
+ auto* ic = ctx->inference_context.get();
+
// Fill input properties.
{
auto& input_properties = input_properties_[node.name()];
@@ -1561,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(input_properties.size(), 0);
- input_properties.resize(ctx->inference_context->num_inputs());
+ input_properties.resize(ic->num_inputs());
GraphView::InputPort input(&node, -1);
- for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->input(i),
- ctx->input_types[i],
+ for (int i = 0; i < ic->num_inputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i],
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
- if (!IsConstant(*fanin.node)) {
- continue;
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to input_properties.value.
+ if (IsConstant(*fanin.node)) {
+ const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+ *input_properties[i].mutable_value() = raw_val;
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
}
- const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
- *input_properties[i].mutable_value() = raw_val;
}
}
@@ -1584,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(output_properties.size(), 0);
- output_properties.resize(ctx->inference_context->num_outputs());
- for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->output(i),
- ctx->output_types[i],
+ output_properties.resize(ic->num_outputs());
+ for (int i = 0; i < ic->num_outputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to output_properties.value.
+ if (IsConstant(node)) {
+ const TensorProto& raw_val = node.attr().at("value").tensor();
+ *output_properties[i].mutable_value() = raw_val;
+ } else if (ctx->output_tensors_as_shapes.size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i])) {
+ *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i]);
+ }
}
}
}
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3ec68a4e59..362092a6cf 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -44,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test {
// Provision a single machine with 3 cpu cores
cluster_.reset(new SingleMachine(5 * 60, 3, 0));
TF_CHECK_OK(cluster_->Provision());
+
+ // This function is simply
+ // out = Fill(shape, value), but
+ // Fill requires values in the shape input, not just shape of it, to infer
+ // output shape.
+ auto f = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ function_lib_.add_function()->Swap(&f);
}
void TearDown() override {
@@ -69,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test {
return s;
}
+ // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
+ // ones.
+ void ExpectTensorValues(const std::vector<int64>& expected,
+ const TensorProto& tensor_proto_to_compare) {
+ Tensor tensor;
+ EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
+ EXPECT_EQ(expected.size(), tensor.NumElements());
+ // We're interested in only integer tensors as only shapes are exported as
+ // graph properties values.
+ CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
+ if (tensor.dtype() == DT_INT32) {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
+ }
+ } else {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
+ }
+ }
+ }
+
std::unique_ptr<SingleMachine> cluster_;
+ FunctionDefLibrary function_lib_;
};
TEST_F(GraphPropertiesTest, StaticProperties) {
@@ -785,32 +831,138 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
+ Output a1 = ops::Identity(s.WithOpName("a1"), a);
+ Output b = ops::Const(s.WithOpName("b"), 99, {});
+ Output b1 = ops::Identity(s.WithOpName("b1"), b);
+ Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
+ Output c1 = ops::Identity(s.WithOpName("c1"), c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ // Check output shapes.
+ EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
+ EXPECT_EQ("int32: [2]",
+ PropToString(properties.GetOutputProperties("a1")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c1")[0]));
+
+ // Check has_value.
+ EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
+ // Note that we propagate tensro value of only 1D vector and scalar.
+ EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value());
+
+ // Check values.
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
+ ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
+ std::vector<int64> c_values;
+ for (int i = 0; i < 4 * 4 * 4; i++) {
+ c_values.push_back(1);
+ }
+ ExpectTensorValues({c_values},
+ properties.GetOutputProperties("c")[0].value());
+ ExpectTensorValues({c_values},
+ properties.GetInputProperties("c1")[0].value());
+ // No output value for c1, as it's neither 1D vector nor scalar.
+}
+
+TEST_F(GraphPropertiesTest, IdentityPassingShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 5, {2});
+ Output b = ops::Identity(s.WithOpName("b"), a);
+ Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also the value of e to figure out output
+ // shape; hence, Identity op (b) should pass a's value as
+ // output_tensors_as_shape.
+ Output d = ops::Fill(s.WithOpName("fill"), b, c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {});
+ Output b = ops::Const(s.WithOpName("b"), 2, {});
+ Output c = ops::Const(s.WithOpName("c"), 3, {});
+ Output d = ops::Const(s.WithOpName("d"), 4, {});
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // Same to PackWithConstInput test case, but a, b, c, and d are Identity ops
+ // from Const.
+ // If output_tensors_as_shape is not not set for those Shape ops or Pack op
+ // doesn't take input_tensors_as_shape, Fill op's input doesn't have value;
+ // hence, its output shape becomes unknown.
+ Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
+ Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
+ Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
+ Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
+ Output a = ops::Identity(s.WithOpName("a"), a0);
+ Output b = ops::Identity(s.WithOpName("b"), b0);
+ Output c = ops::Identity(s.WithOpName("c"), c0);
+ Output d = ops::Identity(s.WithOpName("d"), d0);
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
- FunctionDefLibrary library;
- // This function is simply
- // out = Fill(shape, value), but
- // Fill requires values in the shape input, not just shape of it, to infer
- // output shape; hence, func
- *library.add_function() = FunctionDefHelper::Create(
- // Name
- "MyFillFunc",
- // Inputs
- {"shape: int32", "value: float"},
- // Outputs
- {"out: float"},
- // Attrs
- {},
- // Nodes
- {
- {{"a"},
- "Fill",
- {"shape", "value"},
- {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
- },
- // Returns
- {{"out", "a:output:0"}});
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
@@ -827,13 +979,69 @@ TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyFillFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_FALSE(out_prop0.shape().unknown_rank());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(1, out_prop0.shape().dim(0).size());
- EXPECT_EQ(2, out_prop0.shape().dim(1).size());
- EXPECT_EQ(3, out_prop0.shape().dim(2).size());
- EXPECT_EQ(4, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
+ // Same to FunctionWithConstInput, but function inputs are Identity of Const,
+ // so tensor shapes, not tensor value, should be used as Const input to
+ // function.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
+ Output shape = ops::Identity(s.WithOpName("shape"), shape_);
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: int32"}, // Inputs
+ {"out: int32"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+
+ // MyFunc takes Const (shape) and passes it with Identity. Expect function
+ // output has the same shape as well as value (output_tensors_as_shape) as
+ // input Const tensor.
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+ EXPECT_TRUE(out_prop0.has_value());
+ ExpectTensorValues({5, 7}, out_prop0.value());
+ ExpectTensorValues({5, 7},
+ properties.GetInputProperties("MyFunc")[0].value());
}
TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
@@ -907,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
@@ -933,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(2, out_props.size());
const OpInfo::TensorProperties& out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(128, out_prop0.shape().dim(0).size());
- EXPECT_EQ(112, out_prop0.shape().dim(1).size());
- EXPECT_EQ(112, out_prop0.shape().dim(2).size());
- EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
const OpInfo::TensorProperties& out_prop1 = out_props[1];
- EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
- EXPECT_EQ(128, out_prop1.shape().dim(0).size());
- EXPECT_EQ(112, out_prop1.shape().dim(1).size());
- EXPECT_EQ(112, out_prop1.shape().dim(2).size());
- EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
const auto in_props = properties.GetInputProperties("y0");
EXPECT_EQ(4, in_props.size());
const OpInfo::TensorProperties& in_prop0 = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
- EXPECT_EQ(1, in_prop0.shape().dim_size());
- EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+ EXPECT_EQ("float: [64]", PropToString(in_prop0));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_EQ(4, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(1, in_prop1.shape().dim(1).size());
- EXPECT_EQ(24, in_prop1.shape().dim(2).size());
- EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
const OpInfo::TensorProperties& in_prop2 = in_props[2];
- EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
- EXPECT_EQ(4, in_prop2.shape().dim_size());
- EXPECT_EQ(128, in_prop2.shape().dim(0).size());
- EXPECT_EQ(224, in_prop2.shape().dim(1).size());
- EXPECT_EQ(224, in_prop2.shape().dim(2).size());
- EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+ EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
const OpInfo::TensorProperties& in_prop3 = in_props[3];
- EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
- EXPECT_EQ(4, in_prop3.shape().dim_size());
- EXPECT_EQ(7, in_prop3.shape().dim(0).size());
- EXPECT_EQ(7, in_prop3.shape().dim(1).size());
- EXPECT_EQ(3, in_prop3.shape().dim(2).size());
- EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+ EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
@@ -1037,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
@@ -1073,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
@@ -1117,28 +1272,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
@@ -1166,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(3, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index aad00ce039..5415324b48 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -127,7 +127,7 @@ static void ExtractExtraProperties(
// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("filename") != std::string::npos) {
+ op_def->input_arg(i).name().find("filename") != string::npos) {
Tensor tensor;
if (!tensor.FromProto(t)) {
continue;
@@ -153,7 +153,7 @@ static void ExtractExtraProperties(
// When the input is a handle (e.g. look up table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("handle") != std::string::npos) {
+ op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
@@ -209,13 +209,13 @@ DeviceProperties GetDeviceInfo(const string& device_str) {
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
if (parsed.type == "GPU") {
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
// We are probably running simulation without linking cuda libraries.
- cuda_gpu_id = CudaGpuId(parsed.id);
+ platform_gpu_id = PlatformGpuId(parsed.id);
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else if (parsed.type == "CPU") {
return GetLocalCPUInfo();
}
@@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
buckets_.begin(), std::plus<uint64>());
}
-std::string TensorSizeHistogram::ToString() const {
- std::string r;
+string TensorSizeHistogram::ToString() const {
+ string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67666..5fd6717712 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@ class TensorSizeHistogram {
uint64 Max() const { return max_; }
uint64 NumElem() const { return num_elem_; }
uint64 SumElem() const { return sum_elem_; }
- std::string ToString() const;
+ string ToString() const;
protected:
const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 02a379fca8..80889afc86 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
// Helper lambda to extract port num from _Send and _Recv op name.
auto get_port_num = [](const string& name) -> int {
- if (name.find("bn_0") != std::string::npos) {
+ if (name.find("bn_0") != string::npos) {
return 0;
- } else if (name.find("bn_1") != std::string::npos) {
+ } else if (name.find("bn_1") != string::npos) {
return 1;
- } else if (name.find("bn_2") != std::string::npos) {
+ } else if (name.find("bn_2") != string::npos) {
return 2;
- } else if (name.find("bn_minus1") != std::string::npos) {
+ } else if (name.find("bn_minus1") != string::npos) {
return -1;
}
return -999;
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..2619a9a8f3 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,44 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+ ++output_arg_id) {
+ if (port_id < 0) {
+ return -1;
+ } else if (port_id == 0) {
+ return output_arg_id;
+ }
+
+ // Default is 1 port per output arg.
+ int n = 1;
+
+ const auto& output_arg = op.output_arg(output_arg_id);
+ if (!output_arg.number_attr().empty()) {
+ n = node.attr().at(output_arg.number_attr()).i();
+ } else if (!output_arg.type_list_attr().empty()) {
+ n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ }
+
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ } else if (port_id < n) {
+ return output_arg_id;
+ }
+ port_id -= n;
+ }
+
+ return -1;
+}
+
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..3d7d2faf7c 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -25,6 +26,88 @@ namespace {
class GraphViewTest : public ::testing::Test {};
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& a_node_def = *graph_view.GetNode("a");
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+ const OpDef* a_op_def = nullptr;
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+ for (int num_splits : {1, 2}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+ ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+ int arg_id = -1;
+ if (port_id < num_splits * 3) {
+ arg_id = port_id / num_splits;
+ }
+ EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+ }
+ }
+}
+
+TEST_F(GraphViewTest, ParseSingleExample) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+ Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
+ ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
+ {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& c_node_def = *graph_view.GetNode("c");
+
+ const OpDef* c_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8));
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff877..def9198a69 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
#include <vector>
+#include "tensorflow/core/platform/env.h"
+
namespace tensorflow {
namespace grappler {
@@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) {
return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
}
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
*status = Env::Default()->FileExists(file);
return status->ok();
}
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd5359f..4b9cb0a9ad 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files,
std::vector<Status>* status = nullptr);
bool FilesExist(const std::set<string>& files);
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e78239bd43..3521669b63 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -491,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
}
}
// Queue ops modify the queue which is a side effect.
- if (node.op().find("Queue") != std::string::npos) {
+ if (node.op().find("Queue") != string::npos) {
return false;
}
return !ModifiesInputsInPlace(node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index f094c151e6..960d1addb3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
- "//tensorflow/core:platform/default/build_config.bzl",
- "tf_protos_grappler",
-)
-load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
@@ -97,7 +93,6 @@ cc_library(
deps = [
":evaluation_utils",
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -261,7 +257,6 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -270,6 +265,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -515,12 +511,14 @@ cc_library(
":custom_graph_optimizer_registry",
":debug_stripper",
":dependency_optimizer",
+ ":experimental_implementation_selector",
":function_optimizer",
":graph_optimizer",
":layout_optimizer",
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
@@ -647,7 +645,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -657,6 +654,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -714,31 +712,6 @@ tf_cuda_cc_test(
)
cc_library(
- name = "symbolic_shapes",
- srcs = ["symbolic_shapes.cc"],
- hdrs = ["symbolic_shapes.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
- name = "symbolic_shapes_test",
- srcs = ["symbolic_shapes_test.cc"],
- deps = [
- ":symbolic_shapes",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],
hdrs = [
@@ -911,3 +884,41 @@ tf_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
+
+cc_library(
+ name = "pin_to_host_optimizer",
+ srcs = ["pin_to_host_optimizer.cc"],
+ hdrs = [
+ "pin_to_host_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "pin_to_host_optimizer_test",
+ srcs = ["pin_to_host_optimizer_test.cc"],
+ deps = [
+ ":pin_to_host_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 11ce121cba..75ed12635e 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -35,8 +35,8 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -1325,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const string node_name = node->name();
NodeDef* x;
NodeDef* y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
bool updated = false;
- if (IsAdd(*node)) {
- if (IsNeg(*x)) {
- // (-a) + b = b - a
- node->set_op("Sub");
- node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, x->input(0));
- node->add_input(AsControlDependency(x->name()));
- ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
- updated = true;
- } else if (IsNeg(*y)) {
- // a + (-b) = a - b
- node->set_op("Sub");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
- } else if (IsSub(*node)) {
- if (IsNeg(*y)) {
- // a - (-b) = a + b
- node->set_op("Add");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
+ if (IsNeg(*y)) {
+ // a - (-b) = a + b or a + (-b) = a - b
+ ForwardControlDependencies(node, {y});
+ ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+ node->set_op(IsAdd(*node) ? "Sub" : "Add");
+ node->set_input(1, y->input(0));
+ updated = true;
+ } else if (IsAdd(*node) && IsNeg(*x)) {
+ // (-a) + b = b - a
+ ForwardControlDependencies(node, {x});
+ ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+ node->set_op("Sub");
+ node->mutable_input()->SwapElements(0, 1);
+ node->set_input(1, x->input(0));
+ updated = true;
}
if (updated) {
AddToOptimizationQueue(node);
@@ -2379,26 +2367,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
- for (int i = 0; i < p.shape().dim_size(); ++i) {
- if (p.shape().dim(i).size() < 0) {
+ const auto& pow_props =
+ ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+ if (pow_props.shape().dim(i).size() < 0) {
// skip if p is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor pow(p.dtype(), p.shape());
- if (!pow.FromProto(p.value())) {
+ if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+ Tensor pow(pow_props.dtype(), pow_props.shape());
+ if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
+ pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
+ if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@@ -2411,12 +2397,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ const auto& value_props =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ const TensorShapeProto& output_shape =
+ ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(1, 0)) {
+ } else if (curr == complex128(1, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ // Pow could be used to broadcast, so make sure the shapes of the two
+ // arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@@ -2426,20 +2419,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(0, 0)) {
- const auto& b =
- ctx().graph_properties->GetInputProperties(node->name())[0];
- for (int i = 0; i < b.shape().dim_size(); ++i) {
- if (b.shape().dim(i).size() < 0) {
+ } else if (curr == complex128(0, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+ if (value_props.shape().dim(i).size() < 0) {
// skip if b is is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(b.shape()) && b.has_value()) {
- Tensor base(b.dtype(), b.shape());
- if (!base.FromProto(b.value())) {
+ if (TensorShape::IsValid(value_props.shape()) &&
+ value_props.has_value()) {
+ Tensor base(value_props.dtype(), value_props.shape());
+ if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
+ value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@@ -2597,12 +2590,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
- if (!IsSub(*node))
- return false;
+ if (!IsSub(*node)) return false;
NodeDef* input;
- if (!GetInputNode(node->input(0), &input).ok())
- return false;
+ if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@@ -2622,10 +2613,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
return Status::OK();
}
- const auto& t =
- ctx().graph_properties->GetInputProperties(exp->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(node->name())[1];
+ const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+ const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
@@ -3053,6 +3042,13 @@ void ArithmeticOptimizer::DedupComputations() {
return;
}
std::set<int> duplicates;
+ // Populate feed_inplace_op;
+ std::unordered_set<NodeDef*> feeds_inplace_op;
+ for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+ if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) {
+ feeds_inplace_op.insert(optimized_graph_->mutable_node(i));
+ }
+ }
do {
stop = true;
UniqueNodes nodes;
@@ -3061,19 +3057,19 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- if (!CanDedup(*node)) {
+ if (!CanDedup(*node) ||
+ feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
if (rep == node) {
continue;
}
- // If either node feeds an inplace op, deduping them may cause data races.
- // For example: If we dedup nodes initializing two independent inplace
- // accumulations, they will write to the same buffer, clobbering each
- // other's results.
- if (FeedsInPlaceOp(graph_view, *rep) ||
- FeedsInPlaceOp(graph_view, *node)) {
+ // If either node or rep feeds an inplace op, deduping them may cause data
+ // races. For example: If we dedup nodes initializing two independent
+ // inplace accumulations, they will write to the same buffer, clobbering
+ // each other's results.
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3081,20 +3077,20 @@ void ArithmeticOptimizer::DedupComputations() {
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
- string* name = fanout->mutable_input(i);
- int position;
- const string nodename = ParseNodeName(*name, &position);
- if (nodename == node->name()) {
- // Update name in-place.
- if (position > 0) {
- *name = StrCat(rep->name(), ":", position);
- } else if (position == 0) {
- *name = rep->name();
- } else {
- *name = StrCat("^", rep->name());
- }
- node_map_->AddOutput(rep->name(), fanout->name());
+ string* fanout_input = fanout->mutable_input(i);
+ const int position =
+ NodePositionIfSameNode(*fanout_input, node->name());
+ // Update name in-place.
+ if (position < -1) {
+ continue;
+ } else if (position > 0) {
+ *fanout_input = StrCat(rep->name(), ":", position);
+ } else if (position == 0) {
+ *fanout_input = rep->name();
+ } else {
+ *fanout_input = StrCat("^", rep->name());
}
+ node_map_->AddOutput(rep->name(), fanout->name());
}
}
duplicates.insert(i);
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 39517edc06..77f3c64c65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
- auto add_all = ops::AddN(s.WithOpName("add_all"),
- {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
- sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+ Output neg_x_with_dep = ops::Neg(
+ s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+ Output add_negx_with_dep_y =
+ ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+ auto add_all =
+ ops::AddN(s.WithOpName("add_all"),
+ {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+ sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
GrapplerItem item;
item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &output);
+ OptimizeTwice(&optimizer, &item, &output);
EXPECT_EQ(item.graph.node_size(), output.node_size());
int found = 0;
@@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
if (node.name() == "Add_negx_y") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
} else if (node.name() == "Add_x_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Add_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Neg_y", node.input(0));
- EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("Neg_x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Sub_x_negy") {
++found;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Sub_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ } else if (node.name() == "Add_negx_with_dep_y") {
+ ++found;
+ EXPECT_EQ("Sub", node.op());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
- EXPECT_EQ("^Neg_x", node.input(3));
+ EXPECT_EQ("^Add_x_y", node.input(2));
}
}
- EXPECT_EQ(5, found);
+ EXPECT_EQ(6, found);
auto tensors = EvaluateNodes(output, item.fetch, feed);
EXPECT_EQ(1, tensors.size());
@@ -2468,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2475,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2503,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2511,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 99737a71eb..ca5d3a6dfd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
- return false;
- }
- }
- return true;
-}
-
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -2125,7 +2106,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node();
axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
- const int axis = node->attr().at("axis").i();
+ const int axis =
+ node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
.ok()) {
@@ -2348,7 +2330,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ const bool y_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@@ -2378,7 +2361,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ const bool x_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 2a19b3f95a..b09360a2c2 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3015,37 +3015,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
auto stack =
ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
ops::Stack::Axis(1));
+ auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
- item.fetch.push_back("stack");
+ item.fetch = {"stack", "stack_no_axis"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
+ EXPECT_EQ(7, output.node_size());
+ int found = 0;
for (const auto& node : output.node()) {
if (node.name() == "stack") {
- EXPECT_EQ("stack", node.name());
EXPECT_EQ("ExpandDims", node.op());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
+ } else if (node.name() == "stack_no_axis") {
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
+ ++found;
} else if (node.name() == "ConstantFolding/stack_const_axis") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
- std::vector<string> fetch = {"stack"};
+ std::vector<string> fetch = {"stack", "stack_no_axis"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
auto tensors = EvaluateNodes(output, fetch);
- EXPECT_EQ(1, tensors_expected.size());
- EXPECT_EQ(1, tensors.size());
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors.size());
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+ EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}
// The test does not evalute the optimized and original graphs to check if their
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 530c957068..cf305cebe1 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -19,7 +19,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -50,14 +49,15 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -68,6 +68,7 @@ tf_cc_test(
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -79,6 +80,40 @@ tf_cc_test(
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -107,7 +142,6 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core/kernels:cast_op",
],
)
@@ -139,7 +173,9 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":graph_utils",
+ ":vectorization_utils",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:grappler_item",
@@ -164,7 +200,6 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work.
],
)
@@ -256,7 +291,6 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -275,6 +309,43 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:control_flow_ops",
+ ],
+)
+
+cc_library(
+ name = "map_parallelization",
+ srcs = ["map_parallelization.cc"],
+ hdrs = [
+ "map_parallelization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_parallelization_test",
+ srcs = ["map_parallelization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_parallelization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
],
)
@@ -355,6 +426,7 @@ cc_library(
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
+ ":map_parallelization",
":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
@@ -375,3 +447,43 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/optimizers/data/vectorization",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000000..e95ea1a4c1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+template <typename Predicate, typename Collection>
+std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ std::vector<int> indices = {};
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ indices.push_back(idx);
+ }
+ idx++;
+ }
+ return indices;
+}
+
+} // namespace
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+ const string& output, int position)
+ : node_name(node_name), node_output(output), position(position) {
+ full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+ // Parses node_name:node_output:position string into its components.
+ full_str = input;
+ StringPiece capture;
+ StringPiece remaining;
+
+ // Parse "node_name"
+ if (strings::Scanner(input)
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_name = string(capture.data(), capture.size());
+ }
+
+ // Parse "node_output" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER)
+ .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_output = string(capture.data(), capture.size());
+ }
+
+ // Parse "position" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .Many(strings::Scanner::DIGIT)
+ .GetResult(nullptr, &capture)) {
+ CHECK(strings::safe_strto32(capture, &position));
+ }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It'll be more
+// performant if we kept mappings of nodes->inputs/outputs, so that we don't
+// have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (eg changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+ FunctionDef* func) {
+ for (NodeDef& n : *func->mutable_node_def()) {
+ std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+ to);
+ }
+
+ for (auto& p : *func->mutable_ret()) {
+ if (p.second == from) {
+ p.second = to;
+ }
+ }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt) {
+ string name = string(prefix);
+ int id = function->signature().output_arg_size();
+ while (ContainsFunctionOutputWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+ output->set_name(name);
+ output->set_type(dt);
+
+ (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().input_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().output_arg());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+ return indices.empty() ? -1 : indices.front();
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; },
+ function.node_def());
+
+ return indices.empty() ? -1 : indices.front();
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = string(prefix);
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000000..d4ce824652
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined
+ string full_str;
+ string node_name;
+ string node_output;
+ int position = -1;
+};
+
+// Replaces all references to `from` tensor in func's nodes' inputs and retvals
+// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name or -1 if the
+// function node does not exist.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name or -1 if the
+// function node does not exist.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using the `prefix` as a prefix while guaranteeing
+// the name is unique across the functions nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000000..3739e20eb1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+ FunctionDefTensorDesc f("Cast:y:0");
+ EXPECT_EQ(f.full_str, "Cast:y:0");
+ EXPECT_EQ(f.node_name, "Cast");
+ EXPECT_EQ(f.node_output, "y");
+ EXPECT_EQ(f.position, 0);
+
+ FunctionDefTensorDesc f2("Arg0");
+ EXPECT_EQ(f2.full_str, "Arg0");
+ EXPECT_EQ(f2.node_name, "Arg0");
+ EXPECT_EQ(f2.node_output, "");
+ EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+ {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+ NodeDef* derive_node =
+ AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+ // Check that both the input to "X" and retval of "outer" are replaced.
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+ EXPECT_EQ(outer.ret().at("out"), "arg0");
+ EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+ FunctionDef function = test::function::XTimesTwo();
+ AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+ EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+ function));
+ EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+ EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+ EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+ EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index 01a78c04b0..b3bfee138f 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function,
auto* if_node = fused_function->add_node_def();
// This is guaranteed to succeed.
TF_CHECK_OK(if_builder.Finalize(if_node));
- graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index d5c6466080..e667affeea 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 5a7fe19265..2dd9ee822e 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper(
} // namespace
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op("Placeholder");
+ SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
+ shape->set_unknown_rank(false);
+ return graph->AddNode(std::move(node));
+}
+
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
@@ -108,26 +118,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
return graph->AddNode(std::move(node));
}
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd) {
- NodeDef* node = fd->add_node_def();
- if (!name.empty()) {
- node->set_name(string(name));
- } else {
- SetUniqueFunctionNodeName(op, fd, node);
- }
- node->set_op(string(op));
- for (const string& input : inputs) {
- node->add_input(input);
- }
- for (auto attr : attributes) {
- (*node->mutable_attr())[attr.first] = attr.second;
- }
- return node;
-}
-
template <>
NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
return AddScalarConstNodeHelper(
@@ -196,6 +186,11 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
+bool ContainsGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return FindGraphFunctionWithName(name, library) != -1;
+}
+
bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
return FindGraphNodeWithName(name, graph) != -1;
}
@@ -204,18 +199,14 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- return FindGraphFunctionWithName(name, library) != -1;
-}
-
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
-}
-
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ std::vector<int> indices = GetElementIndicesWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
+ return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
@@ -237,31 +228,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
-}
-
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
@@ -273,7 +239,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- if (name.rfind("_generated") != std::string::npos &&
+ if (name.rfind("_generated") != string::npos &&
(name.rfind("_generated") == (name.size() - strlen("_generated")))) {
name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
} else {
@@ -284,17 +250,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = string(prefix);
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
string name = string(prefix);
@@ -305,7 +260,6 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
-
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 6f431c232d..b117482db2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -37,11 +37,8 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
-// Adds a node to a FunctionDef.
-NodeDef* AddNode(StringPiece name, StringPiece op,
- const std::vector<string>& inputs,
- const std::vector<std::pair<string, AttrValue>>& attributes,
- FunctionDef* fd);
+// Adds Placeholder node for given type.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
// Adds a Const node with the given value to the graph.
template <typename T>
@@ -76,13 +73,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -95,14 +85,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -119,11 +101,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
// Sets the node name using the `prefix` name as a prefix while guaranteeing the
// name is unique across the graph.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index c19ac7b880..6877c207c4 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -112,20 +112,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +136,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) {
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -225,21 +195,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,43 +206,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
-TEST(GraphUtilsTest, AddNodeToFunctionDef) {
- FunctionDef func;
- const char* op_name = "xxx";
- AddNode(op_name, op_name, {}, {}, &func);
-
- const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
- EXPECT_EQ(node1.op(), op_name);
- EXPECT_EQ(node1.input_size(), 0);
- EXPECT_EQ(node1.attr_size(), 0);
-
- const std::vector<string> inputs({"input1", "input2"});
- AddNode("", op_name, inputs, {}, &func);
- const NodeDef& node2 =
- func.node_def(FindFunctionNodeWithName("xxx/_2", func));
- EXPECT_EQ(node2.op(), op_name);
- EXPECT_EQ(node2.attr_size(), 0);
- EXPECT_EQ(node2.input_size(), inputs.size());
- for (size_t i = 0; i < inputs.size(); ++i) {
- EXPECT_EQ(node2.input(i), inputs[i]);
- }
-
- AttrValue a1, a2;
- a1.set_type(DT_INT32);
- a2.set_type(DT_INT64);
- const std::vector<std::pair<string, AttrValue>> attrs(
- {{"attr1", a1}, {"attr2", a2}});
- AddNode("", op_name, {}, attrs, &func);
- const NodeDef& node3 =
- func.node_def(FindFunctionNodeWithName("xxx/_3", func));
- EXPECT_EQ(node3.op(), op_name);
- EXPECT_EQ(node3.input_size(), 0);
- EXPECT_EQ(node3.attr_size(), attrs.size());
- for (size_t i = 0; i < attrs.size(); ++i) {
- EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
- }
-}
-
TEST(GraphUtilsTest, GetInputNode) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
new file mode 100644
index 0000000000..305325e434
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -0,0 +1,106 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool CanParallelize(const FunctionDef& function,
+ const FunctionLibraryDefinition& library) {
+ if (!function.signature().is_stateful()) return true;
+
+ for (const auto& node : function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+ // Assert is marked as stateful, but it does not have any state (except
+ // changing io). Similarly to CUDA, we do not give guarantee that the
+ // assert operation that would fail would be the first one, so that we can
+ // parallelize it.
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false;
+ }
+
+ return true;
+}
+
+NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
+ NodeDef parallel_map = map_node;
+ graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
+ &parallel_map);
+ parallel_map.set_op("ParallelMapDataset");
+ // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
+ // so that dynamic tunning will pick the optimal value at runtime. Because
+ // this feature is not yet implemented, we set it to 2, which is the smallest
+ // value that introduces parallelism.
+ auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph);
+ parallel_map.add_input(num_parallel_calls->name());
+
+ return parallel_map;
+}
+
+} // namespace
+
+Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ auto* function =
+ function_library.Find(map_node->attr().at("f").func().name());
+ if (!CanParallelize(*function, function_library)) continue;
+
+ auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
+ graph.ReplaceInput(*map_node, *parallel_map);
+
+ // TODO(prazek): we could also remove map functions from library if they
+ // are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
new file mode 100644
index 0000000000..ac9cf7e12a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization parallelizes MapDataset when function is stateless.
+class MapParallelization : public CustomGraphOptimizer {
+ public:
+ MapParallelization() = default;
+ ~MapParallelization() override = default;
+
+ string name() const override { return "map_parallelization"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
new file mode 100644
index 0000000000..b2a5d9b6af
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", {}},
+ {"output_types", {}}});
+}
+
+const char stateless_fun_name[] = "XTimesTwo";
+const char stateful_fun_name[] = "RandomUniform";
+
+TEST(MapParallelizationTest, ParallelizeSimpleMap) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateless_fun_name)},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+TEST(MapParallelization, ParallelizeAssert) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateful_fun_name),
+ MakeMapNode("map2", "map1", stateless_fun_name),
+ NDef("cache", "CacheDataset", {"map2", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ test::function::RandomUniform(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index a019b77eb7..7a2f1910da 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -37,11 +39,11 @@ void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
(*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
}
-FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
const FunctionDef& orig_func,
FunctionDefLibrary* library) {
- // If we decide to use a different method of vectorization, we can just
- // swap out this part.
FunctionDef* vectorized_func = library->add_function();
// Function inputs and outputs are the same as original, just
// with different shapes.
@@ -52,8 +54,8 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// Add MapDefun node
NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
map_defun_node->set_op("MapDefun");
- graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func,
- map_defun_node);
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
// Set attrs and inputs
for (const string& k : {"f", "output_types", "output_shapes"}) {
@@ -81,6 +83,30 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
return vectorized_func;
}
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing
+ // efficient vectorization with VectorizeMapDefun.
+ FunctionDef* vectorized_func =
+ CreateMapDefunWrapper(map_node, orig_func, library);
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
+ DCHECK_EQ(map_defun_node->op(), "MapDefun");
+
+ // Create a copy of the original function so that we can mutate it, and
+ // attach that to the map defun node.
+ FunctionDef* map_defun_fn = library->add_function();
+ *map_defun_fn = orig_func;
+ graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
+ map_defun_fn);
+ (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
+ map_defun_fn->signature().name());
+
+ vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
+ map_defun_node);
+ return vectorized_func;
+}
+
bool IsOutputShapesFullyDefined(const NodeDef& node) {
auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
if (shapes_attr == nullptr) return false;
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index a26f1000a3..cf5a19bab1 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -33,25 +33,27 @@ namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
- const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ const auto& count_node = *graph.GetNode(take_node.input(1));
+ if (count_node.op() != "Const") return false;
// We are looking only for 'take' with negative count.
return count_node.attr().at("value").tensor().int64_val(0) < 0;
}
+bool IsConstNodeWithValue(const NodeDef& node, int value) {
+ if (node.op() != "Const") return false;
+ return node.attr().at("value").tensor().int64_val(0) == value;
+}
+
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
// We are looking only for skip(0) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 0;
+ return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
// We are looking only for repeat(1) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 1;
+ return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
index f445e75aa7..be1a66df75 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
GetCommonAttributes(), graph);
}
+NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
+ MutableGraphView *graph) {
+ NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
+ return graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph);
+}
+
NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
NodeDef *node_filename =
graph_utils::AddScalarConstNode<StringPiece>("", graph);
@@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(*kTakeNode, *kSkipNode,
*kRepeatNode)));
+struct NoOpPlaceholdersTest
+ : ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
+ std::get<1>(GetParam())};
+ NodeDef *range_node = MakeRangeNode(&graph);
+ std::vector<string> nodes_to_keep;
+ nodes_to_keep.reserve(noop_nodes.size());
+ NodeDef *previous = range_node;
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
+ nodes_to_keep.push_back(node->name());
+ previous = node;
+ }
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ for (const auto &noop_node_name : nodes_to_keep)
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DoNotRemovePlaceholders, NoOpPlaceholdersTest,
+ ::testing::Combine(
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"),
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset")));
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
new file mode 100644
index 0000000000..1462cb234d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -0,0 +1,69 @@
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+VECTORIZER_DEPS = [
+ ":vectorizer_registry",
+ "//tensorflow/core/grappler/optimizers/data:function_utils",
+] + tf_protos_all()
+
+cc_library(
+ name = "vectorizer",
+ hdrs = ["vectorizer.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "vectorizer_registry",
+ srcs = ["vectorizer_registry.cc"],
+ hdrs = ["vectorizer_registry.h"],
+ deps = [
+ ":vectorizer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "cast_vectorizer",
+ srcs = ["cast_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "unpack_vectorizer",
+ srcs = ["unpack_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "vectorization",
+ hdrs = ["vectorizer_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cast_vectorizer",
+ ":unpack_vectorizer",
+ ":vectorizer",
+ ":vectorizer_registry",
+ ],
+)
+
+tf_cc_test(
+ name = "vectorizer_registry_test",
+ srcs = ["vectorizer_registry_test.cc"],
+ deps = [
+ ":vectorizer_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
new file mode 100644
index 0000000000..c1739737a0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class CastVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Cast op should only have one input.");
+ }
+
+ // Add new Cast node
+ NodeDef* new_cast_node = outer_scope->add_node_def();
+ *new_cast_node = node;
+ new_cast_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_cast_node);
+ new_cast_node->set_input(0, inputs[0]);
+
+ // Add the output mapping to conversion map
+ (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
+ strings::StrCat(new_cast_node->name(), ":y:0");
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Cast", CastVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
new file mode 100644
index 0000000000..776d3179c5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class UnpackVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ if (inputs.size() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
+ }
+
+ // Add new Unpack node
+ NodeDef* new_unpack_node = outer_scope->add_node_def();
+ *new_unpack_node = node;
+ new_unpack_node->clear_name();
+ function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("vectorized/", node.name()), outer_scope,
+ new_unpack_node);
+
+ // Increment "axis" attr by 1:
+ (*new_unpack_node->mutable_attr())["axis"].set_i(
+ node.attr().at("axis").i() + 1);
+ new_unpack_node->set_input(0, inputs[0]);
+
+ // Add the output mappings to conversion map
+ int num = new_unpack_node->attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
+ strings::StrCat(new_unpack_node->name(), ":output:", i);
+ }
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
new file mode 100644
index 0000000000..d341dbba7d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
+// for an example.
+class Vectorizer {
+ public:
+ virtual ~Vectorizer() {}
+
+ // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // that produce the same vector output(s) as executing `node`'s op
+ // on elements of the vector inputs, and adding mappings to `conversion_map`
+ // from old output tensor names to new (vectorized) output tensor names.
+ // The new node(s) collectively have the same number of inputs and outputs as
+ // the node being converted, and use the tensor names in `inputs` as their
+ // inputs.
+ virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) = 0;
+};
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
new file mode 100644
index 0000000000..a6551e36ac
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+VectorizerRegistry* VectorizerRegistry::Global() {
+ static VectorizerRegistry* registry = new VectorizerRegistry;
+ return registry;
+}
+
+Vectorizer* VectorizerRegistry::Get(const string& op_type) {
+ auto found = vectorizers_.find(op_type);
+ if (found == vectorizers_.end()) {
+ return nullptr;
+ }
+ return found->second.get();
+}
+
+void VectorizerRegistry::Register(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ auto existing = Get(op_type);
+ CHECK_EQ(existing, nullptr)
+ << "Vectorizer for op type: " << op_type << " already registered";
+ vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
+ op_type, std::move(vectorizer)));
+}
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
new file mode 100644
index 0000000000..16159d47ca
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+
+#include <functional>
+#include <map>
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// A global VectorizerRegistry is used to hold all the vectorizers.
+class VectorizerRegistry {
+ public:
+ // Returns a pointer to a global VectorizerRegistry object.
+ static VectorizerRegistry* Global();
+
+ // Returns a pointer to a vectorizer that can vectorize an op for the op type.
+ Vectorizer* Get(const string& op_type);
+
+ // Registers a vectorizer that can vectorize an op for the given op type.
+ void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer);
+
+ private:
+ std::map<string, std::unique_ptr<Vectorizer>> vectorizers_;
+};
+
+namespace vectorizer_registration {
+
+class VectorizerRegistration {
+ public:
+ VectorizerRegistration(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer));
+ }
+};
+
+} // namespace vectorizer_registration
+
+#define REGISTER_VECTORIZER(op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorization_utils:: \
+ vectorizer_registration::VectorizerRegistration \
+ vectorizer_registration_##ctr( \
+ op_type, \
+ ::std::unique_ptr< \
+ ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
+ new vectorizer()))
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
new file mode 100644
index 0000000000..86e303564b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class TestVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+ FunctionDef* outer_scope,
+ std::map<string, string>* conversion_map) override {
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("test_op", TestVectorizer);
+
+TEST(TestVectorizer, TestTestVectorizer) {
+ EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr);
+
+ auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
+ EXPECT_NE(vectorizer, nullptr);
+
+ FunctionDef function;
+ NodeDef node;
+ std::map<string, string> conversion_map;
+ EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+}
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..cb56b65985
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,292 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+ const string& output_retval, const DataType t) {
+ // Set to unknown shape
+ TensorShapeProto tensor_shape_proto;
+ PartialTensorShape().AsProto(&tensor_shape_proto);
+
+ function_utils::AddFunctionOutputWithUniqueName(
+ "vectorized_out", output_retval, map_defun_fn, t);
+
+ *(*map_defun_node->mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape() = tensor_shape_proto;
+ (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+}
+
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node, int output_position) {
+ DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs =
+ map_defun_fn->signature().output_arg_size() - output_position - 1;
+
+ // Remove from map_defun_fn's ret dict and output args
+ map_defun_fn->mutable_ret()->erase(
+ map_defun_fn->signature().output_arg(output_position).name());
+ map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+ output_position, 1);
+
+ // Renumber outputs that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i + 1),
+ strings::StrCat(map_defun_node->name(),
+ ":output:", output_position + i),
+ outer_scope);
+ }
+ map_defun_node->mutable_attr()
+ ->at("output_shapes")
+ .mutable_list()
+ ->mutable_shape()
+ ->DeleteSubrange(output_position, 1);
+ map_defun_node->mutable_attr()
+ ->at("output_types")
+ .mutable_list()
+ ->mutable_type()
+ ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+int FindOutputToConvert(const FunctionDef& function,
+ const std::set<string>& unconvertible,
+ FunctionDefTensorDesc* f) {
+ for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
+ const string& ret_key = function.signature().output_arg(i).name();
+ *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+ if (unconvertible.find(f->node_name) == unconvertible.end()) {
+ return i;
+ }
+ }
+ return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+ Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node)
+ : outer_scope_(outer_scope),
+ map_defun_fn_(map_defun_fn),
+ map_defun_node_(map_defun_node) {}
+
+ // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+ // the outer_scope_, until there are no convertible outputs remaining.
+ // This method is idempotent.
+ void Vectorize();
+
+ private:
+ // Vectorizes the map defun function's output at output_position
+ Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+ // Given a descriptor of the original output tensor, gets a string
+ // corresponding to the converted output tensor.
+ Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+ string* converted);
+ Status AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc);
+
+ // Adds mappings from node's outputs tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
+ // and modify map_defun_node_ attrs accordingly
+ // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // outer_scope_, since the vectorized version of Cast is itself.
+ // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+ // are now outputs of map_defun_node_)
+ // 4) For each output of the old node, add the mapping of output strings to
+ // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
+ Status AddConversionMappingFromOp(const NodeDef& node,
+ const FunctionDefTensorDesc& output_desc);
+
+ // Maps a tensor name to the name of the corresponding vectorized tensor. For
+ // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
+ std::map<string, string> conversion_map_;
+ // Unconvertible node names
+ std::set<string> unconvertible_;
+
+ FunctionDef* outer_scope_;
+ FunctionDef* map_defun_fn_;
+ NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+ const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+ for (const string& input_name : node.input()) {
+ if (IsControlInput(input_name)) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ // TODO(rachelim): Have some mechanism for registering converters and some
+ // uniform, simpler way to represent them.
+
+ DataTypeVector types;
+ const OpDef* op_def = nullptr;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+ TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+ std::vector<string> promoted_inputs;
+ promoted_inputs.reserve(node.input_size());
+ for (int i = 0; i < node.input_size(); ++i) {
+ promoted_inputs.push_back(strings::StrCat(
+ map_defun_node_->name(),
+ ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ }
+
+ auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ node.op());
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
+ &conversion_map_));
+
+ // If we get here, the conversion was successful, so we promote the inputs
+ // of the ops to MapDefun outputs.
+ for (int i = 0; i < types.size(); ++i) {
+ AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ }
+
+ return Status::OK();
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+ const FunctionDefTensorDesc& output_desc) {
+ int input_index = function_utils::FindFunctionInputWithName(
+ output_desc.node_name, *map_defun_fn_);
+ if (input_index == -1) {
+ return errors::Internal("Cannot convert non-existent input.");
+ }
+
+ conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+ const FunctionDefTensorDesc& output_desc, string* converted) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+ *converted = *found;
+ return Status::OK();
+ }
+
+ int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+ *map_defun_fn_);
+ if (index == -1) { // The output comes from an input
+ TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+ map_defun_fn_->node_def(index), output_desc));
+ }
+ *converted = conversion_map_.at(output_desc.full_str);
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+ const FunctionDefTensorDesc& output_desc) {
+ string converted_output_name;
+ TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+ // Remove the old output and make everything that referenced it point
+ // to the new string
+ function_utils::ReplaceReferences(
+ strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+ converted_output_name, outer_scope_);
+ RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+ output_position);
+
+ return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+ while (true) {
+ FunctionDefTensorDesc desc;
+ int output_position =
+ FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+ if (output_position == -1) break;
+
+ if (!ConvertOutput(output_position, desc).ok()) {
+ unconvertible_.insert(desc.node_name);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->signature().output_arg_size() == 0) {
+ outer_scope_->mutable_node_def()->DeleteSubrange(
+ function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
+ *outer_scope_),
+ 1);
+ }
+
+ if (!unconvertible_.empty()) {
+ VLOG(2) << "The following nodes could not be converted: ["
+ << absl::StrJoin(unconvertible_, ", ") << "].";
+ }
+}
+} // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node) {
+ Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bb405faa77
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a function, `map_defun_fn`, that is mapped across some input vector
+// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
+// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
+// `outer_scope`; that is, replacing `map_defun_fn` operations with new
+// `outer_scope` operations that produce the same vector output(s) as executing
+// the `map_defun_fn` operations on elements of vector input(s) would. If all
+// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
+// eliminated from `outer_scope` altogether. However, if some operations cannot
+// be lifted, and this vectorization only succeeds partially, `map_defun_node`
+// remains to be used for operations that were not lifted.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to the `outer_scope`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+ NodeDef* map_defun_node);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000000..e129fa9237
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,600 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
+ EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | +---v--+ | |
+// | | |Const | | | Op0 | | |
+// | | +---v--+ | +---+--+ | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | XOp2 | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 and XOp2 are not convertible.
+//
+// After:
+//
+// No change because the ops are not convertible.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ NodeDef* x_op1 =
+ function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+
+ NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
+ CHECK_NOTNULL(x_op2);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
+ const NodeDef& cast_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ EXPECT_EQ(cast_node.input(0), "x");
+ const NodeDef& unpack_node =
+ outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(outer.ret().at("mapdefun"),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(outer.node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ // The attrs aren't relevant
+ NodeDef* print_op =
+ function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDef outer_copy(outer);
+ FunctionDef inner_copy(inner);
+ VectorizeMapDefun(&outer, &inner, map_defun);
+ // They should be unchanged
+ EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 9701a038d0..800160e649 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {
if (!IsControlInput(inp)) {
- inp = AsControlDependency(inp);
+ inp = AsControlDependency(NodeName(inp));
}
}
} else if (IsCheckNumerics(node) || IsPrint(node)) {
@@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// input.
for (size_t i = 1; i < node.input_size(); ++i) {
if (!IsControlInput(node.input(i))) {
- *node.mutable_input(i) = AsControlDependency(node.input(i));
+ *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i)));
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 96ceee791f..affd2d51c2 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
CompareGraphs(item.graph, output);
}
+TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape({6}));
+ auto split =
+ ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2);
+ Output x = split[0];
+ Output y = split[1];
+ Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("add").WithControlDependencies({assert.operation}), x, y);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ for (const string& input : node.input()) {
+ if (IsControlInput(input)) {
+ EXPECT_EQ(input.find(':'), -1);
+ }
+ }
+ }
+}
+
TEST_F(DebugStripperTest, StripAssertFromGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
index eeea269fb0..2c36c9b7b3 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -32,8 +32,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-REGISTER_GRAPH_OPTIMIZER(ExperimentalImplementationSelector);
-
Status ExperimentalImplementationSelector::LoadFunctions(
const GraphDef& graph) {
lib_info_.reset(new FunctionLibraryApiInfo);
@@ -43,8 +41,20 @@ Status ExperimentalImplementationSelector::LoadFunctions(
Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
NodeDef* node_def) const {
- const FunctionApiInfo* info = lib_info_->GetApiInfo(node_def->op());
- if (info == nullptr) {
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, or
+ // 2. Via the @defun functional interface, where the real function name
+ // appear as the attribute with type func.
+ std::vector<string> function_attribute_names;
+ for (const auto& attr : node_def->attr()) {
+ if (attr.second.has_func() &&
+ lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+ function_attribute_names.emplace_back(attr.first);
+ }
+ }
+
+ if (function_attribute_names.empty() &&
+ lib_info_->GetApiInfo(node_def->op()) == nullptr) {
// A regular op, or a function which has no interface.
return Status::OK();
}
@@ -58,17 +68,25 @@ Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
DeviceNameUtils::ParsedName parsed_name;
DeviceNameUtils::ParseLocalName(device, &parsed_name);
- string best_function_name;
- lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
- &best_function_name);
- if (node_def->op() != best_function_name) {
- // The current implementation is not the best, swap the op to the best one.
- // There will be duplicates in the graph and they will be pruned by other
- // grappler plugin since no other node is using their output as inputs.
- // TODO(scottzhu): Update the tf.eager.defun to register functions without
- // having to call them with input data. That will reduce the graph size and
- // save the work for prune them.
- node_def->set_op(best_function_name);
+ for (const auto& attr_name : function_attribute_names) {
+ string function_name = node_def->attr().at(attr_name).func().name();
+ string best_function_name;
+ lib_info_->GetBestImplementation(function_name, parsed_name.type,
+ &best_function_name);
+ if (function_name != best_function_name) {
+ node_def->mutable_attr()
+ ->find(attr_name)
+ ->second.mutable_func()
+ ->set_name(best_function_name);
+ }
+ }
+ if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ node_def->set_op(best_function_name);
+ }
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
index 2368e577c2..3f1ebefac6 100644
--- a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -45,9 +45,8 @@ TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
- std::unique_ptr<CustomGraphOptimizer> optimizer =
- CustomGraphOptimizerRegistry::CreateByNameOrNull(
- "ExperimentalImplementationSelector");
+ std::unique_ptr<CustomGraphOptimizer> optimizer(
+ new ExperimentalImplementationSelector);
ASSERT_NE(nullptr, optimizer);
TF_ASSERT_OK(optimizer->Init());
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index b75d6303b4..c59645e5f2 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -104,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
+ MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -132,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
+ optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -196,14 +202,34 @@ Status MetaOptimizer::InitializeOptimizersByName(
Status MetaOptimizer::InitializeCustomGraphOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
- auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
- optimizer_config.name());
+ // Initialize the ExperimentalImplementationSelector here instead of
+ // CustomizeOptimizer registry, due the static link issue in TensorRT for
+ // double registry.
+ // TODO(laigd): Remove this hack and change it back to use the registry once
+ // the duplicate static import issue is fixed.
+ std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+ if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+ custom_optimizer.reset(new ExperimentalImplementationSelector());
+ } else {
+ custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ optimizer_config.name());
+ }
if (custom_optimizer) {
VLOG(2) << "Registered custom configurable graph optimizer: "
<< optimizer_config.name();
TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
optimizers->push_back(std::move(custom_optimizer));
} else {
+ // If there are no custom optimizers with given name, try to initalize a
+ // default optimizer. This way, custom configurable optimizers can be
+ // mixed with default optimizers in any order.
+ auto optimizer = MakeNewOptimizer(optimizer_config.name());
+ if (optimizer) {
+ VLOG(2) << "Registered default graph optimizer: "
+ << optimizer_config.name();
+ optimizers->push_back(std::move(optimizer));
+ continue;
+ }
VLOG(2) << "Can't register an optimizer by name: "
<< optimizer_config.name();
}
@@ -341,10 +367,12 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+ VLOG(1) << "Optimized main graph.";
// Skip optimizing functions if this is a TPU graph. Currently, Grappler
// passes do not handle TPU functions correctly in a variety of ways (Note
@@ -421,7 +449,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- VLOG(3) << "Optimized " << optimized_funcs.size()
+ VLOG(1) << "Optimized " << optimized_funcs.size()
<< " functions: " << str_util::Join(optimized_funcs, ", ");
return Status::OK();
@@ -455,6 +483,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..2190d38937
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,264 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+ // Try find KernelDef for node.device, else GPU or CPU.
+ for (const DeviceType& device :
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+ Status s = FindKernelDef(device, node, kdef, nullptr);
+ if (s.ok()) {
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check if all node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
+ // Loop through all the inputs excluding the controlling nodes.
+ for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+ // Check if (the fanin) op's device is on CPU.
+ if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Check if (the fanin) op's output port is pinned to HostMemory.
+ const OpDef* fanin_odef = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+ return false;
+ }
+
+ const int output_arg_id =
+ OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+ << node.DebugString() << "\n"
+ << fanin.node->DebugString() << "\n"
+ << fanin_odef->DebugString();
+ return false;
+ }
+
+ const KernelDef* fanin_kdef = nullptr;
+ s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+ return false;
+ }
+
+ bool fanin_pinned = false;
+ for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+ if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+ fanin_pinned = true;
+ break;
+ }
+ }
+
+ if (!fanin_pinned) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check if Tensor is integer and small size.
+
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+ const NodeDef& node) {
+ for (const auto& prop : properties.GetInputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+
+ for (const auto& prop : properties.GetOutputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
+ // Force this node onto the CPU.
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ // Sometimes the cluster can have:
+ // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+ // and we need to handle them properly.
+ for (const auto& device_match :
+ {std::pair<string, string>("GPU", "CPU:0"),
+ std::pair<string, string>("/device", "/device:CPU:0")}) {
+ const string device_host =
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ device_match.second);
+ if (devices.find(device_host) != devices.end()) {
+ return device_host;
+ }
+ }
+ }
+
+ // We couldn't find an appropriate Host device, return original device.
+ return device;
+}
+
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
+} // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ // Skip all TPU graphs.
+ if (internal::IsTPUGraphDef(*optimized_graph)) {
+ return Status::OK();
+ }
+
+ GraphProperties properties(item);
+ bool has_properties = false;
+ GraphView graph(optimized_graph);
+
+ gtl::FlatSet<string> devices;
+ if (cluster) {
+ const std::vector<string> device_names = cluster->GetDeviceNames();
+ devices.insert(device_names.begin(), device_names.end());
+ } else {
+ devices = {"/device:CPU:0"};
+ }
+
+ const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+ // Topologically sort the graph, so that we traverse the nodes in order. This
+ // will help us discover producer->consumer chains of Host ops.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+
+ // All the Const nodes, and their original devices in topological order.
+ std::vector<std::pair<NodeDef*, string>> const_nodes;
+
+ for (auto& node : *optimized_graph->mutable_node()) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Skip these node types.
+ if (internal::IsBlacklisted(node)) {
+ continue;
+ }
+
+ // Check the node can be run on CPU.
+ Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+ if (!s.ok()) {
+ continue;
+ }
+
+ // Check all input's are pinned to CPU.
+ if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+ continue;
+ }
+
+ if (!has_properties) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ has_properties = true;
+ }
+
+ // Check all inputs and outputs are integers and small.
+ if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ continue;
+ }
+
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
+ }
+ // Try and swap the device to Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+ }
+
+ // Traverse all `const_nodes`, and map them back to GPU greedily.
+ for (auto& it : const_nodes) {
+ NodeDef* node = it.first;
+ const string& device = it.second;
+
+ // Check all the consumers of this node, if any of them are on the original
+ // device, swap this node back onto the original device.
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
+ if (fanout.node->device() == device) {
+ node->set_device(device);
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try and find an appropriate Host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
+} // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped into the CPU to avoid
+// excessive cpu<->gpu memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
+// gpu->gpu->gpu may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+ PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+ explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~PinToHostOptimizer() override {}
+
+ string name() const override { return "pin_to_host_optimizer"; };
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..173cb3fe3c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,194 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+ gtl::FlatSet<string> devices = {};
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // Reverse the graph, and hence rely on the optimizer to sort it.
+ std::reverse(item.graph.mutable_node()->begin(),
+ item.graph.mutable_node()->end());
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, NoSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `b` should be too big to swap, consequently `c` should not be swapped.
+ // PinToHostOptimizer should then detect that `a` should not be swapped.
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
+ Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
+ Output c = ops::MatMul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch = {"a", "b", "c"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_TRUE(node.device().empty());
+ ++found;
+ }
+ EXPECT_EQ(found, 3);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GrapplerItem item;
+ item.fetch = {"a", "b"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ ++found;
+ }
+ EXPECT_EQ(found, 2);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 03e36a7b9c..008a289cfd 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -218,7 +218,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
- // Nothing to do for ArithmeticOptimizer.
+ // Nothing to do for RemapperOptimizer.
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index caa0b7b0cb..4542d17ccc 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 153785d3b4..db6e4e6852 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
+#include <iterator>
#include <memory>
#include <queue>
#include <vector>
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -165,6 +167,34 @@ int NodePosition(const string& name) {
return position;
}
+int NodePositionIfSameNode(const string& input_name, const string& node_name) {
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 20dbeea2cf..296ee1678e 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -107,6 +107,7 @@ bool IsSameInput(const string& name1, const string& name2);
string NodeName(const string& name);
// Get the trailing position number ":{digits}" (if any) of a node name.
+// Returns -1 for control inputs.
int NodePosition(const string& name);
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
@@ -142,6 +143,11 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
+// Returns NodePosition(input_name) if NodeName(input_name) == node_name.
+// Otherwise returns -2;
+// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0.
+int NodePositionIfSameNode(const string& input_name, const string& node_name);
+
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter);
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0476..bdbb8836e1 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_protos_grappler",
+)
cc_library(
name = "scc",
@@ -210,3 +214,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "symbolic_shapes",
+ srcs = ["symbolic_shapes.cc"],
+ hdrs = ["symbolic_shapes.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+ name = "symbolic_shapes_test",
+ srcs = ["symbolic_shapes_test.cc"],
+ deps = [
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- cfg->set_constant_folding(RewriterConfig::OFF);
+ // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+ // off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
- cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
- cfg->set_debug_stripper(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4fb7aab647..ceb9f5dbf2 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -24,15 +24,16 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Compute modified strongly connected components:
+// Computes modified strongly connected components:
// All nodes that are not part of a loop are assigned the special -1 id
// All nodes that are part of at least one loop are assigned a positive
// component id: if 2 nodes v and w are reachable from one another (i.e. if they
// belong to the same scc), they'll be assigned the same id, otherwise they'll
-// be assigned distinct ids. Returns the number of distinct ids.
+// be assigned distinct ids. *num_components is set to the number of distinct
+// ids.
void StronglyConnectedComponents(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
- int* num_ids);
+ int* num_components);
// Returns the number of individual loops present in the graph, and populate the
// 'loops' argument with the collection of loops (denoted by their loop ids) a
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a744..1666de4b80 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1fe7..0a7d8ac82b 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
} // namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d1c5..6ac644cdb1 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index c6e035834c..6b787a6910 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
@@ -147,6 +148,21 @@ TEST_F(UtilsTest, NodePosition) {
EXPECT_EQ(0, NodePosition(""));
}
+TEST_F(UtilsTest, NodePositionIfSameNode) {
+ EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
+ EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
+}
+
TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
@@ -209,7 +225,6 @@ TEST_F(UtilsTest, GetTailOfChain) {
auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
- LOG(INFO) << graph.DebugString();
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
@@ -336,9 +351,26 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
}
TEST_F(UtilsTest, DeleteNodes) {
- // TODO(rmlarsen): write forgtten test.
+ // TODO(rmlarsen): write forgotten test.
}
+#define BM_NodePositionIfSameNode(I, N, NAME) \
+ static void BM_NodePositionIfSameNode_##NAME(int iters) { \
+ string input = I; \
+ string node = N; \
+ for (int i = 0; i < iters; ++i) { \
+ const int pos = NodePositionIfSameNode(input, node); \
+ CHECK_GT(pos, -3); \
+ } \
+ } \
+ BENCHMARK(BM_NodePositionIfSameNode_##NAME)
+
+BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7);
+BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0);
+BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
+BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
+BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c3c6013d83..1a3db2c7cd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
+ "tf_cc_test_mkl",
"tf_cc_tests",
"tf_cc_binary",
"tf_copts",
@@ -50,6 +51,10 @@ load(
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
@@ -212,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -617,6 +635,7 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
@@ -636,6 +655,7 @@ cc_library(
":reshape_op",
":reverse_op",
":reverse_sequence_op",
+ ":searchsorted_op",
":shape_ops",
":slice_op",
":snapshot_op",
@@ -869,6 +889,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "searchsorted_op",
+ prefix = "searchsorted_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "inplace_ops",
prefix = "inplace_ops",
deps = ARRAY_DEPS,
@@ -1105,7 +1131,7 @@ tf_cuda_cc_test(
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -2702,6 +2728,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2759,6 +2786,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4396,6 +4424,7 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
@@ -4405,8 +4434,16 @@ cc_library(
],
)
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4427,6 +4464,30 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
@@ -4504,6 +4565,25 @@ tf_kernel_library(
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "substr_op_test",
+ size = "small",
+ srcs = ["substr_op_test.cc"],
+ deps = [
+ ":substr_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "as_string_op",
prefix = "as_string_op",
@@ -5094,6 +5174,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5262,6 +5343,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -6209,6 +6291,26 @@ tf_mkl_kernel_library(
] + mkl_deps(),
)
+tf_cc_test_mkl(
+ name = "mkl_conv_ops_test",
+ size = "small",
+ srcs = ["mkl_conv_ops_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 7b28c8e91f..e15ea82e7d 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> {
if (data_format_ == FORMAT_NCHW) {
int32 batch, height, width, channel;
GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
- Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
- Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width);
const Device& d = context->eigen_device<Device>();
output->tensor<T, 4>().device(d) =
input.tensor<T, 4>() +
@@ -247,14 +247,14 @@ class BiasGradOp : public OpKernel {
OP_REQUIRES(context, output_backprop.dims() == 4,
errors::InvalidArgument(
"NCHW format supports only 4D input/output tensor."));
- Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
#ifdef EIGEN_HAS_INDEX_LIST
using idx0 = Eigen::type2index<0>;
using idx2 = Eigen::type2index<2>;
using idx3 = Eigen::type2index<3>;
Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
#else
- Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+ Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
@@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel {
.sum(reduction_axes)
.template cast<T>(); // End of code by intel_tf.
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+ channel);
#ifdef EIGEN_HAS_INDEX_LIST
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<Eigen::Index, 1> reduction_axis = {0};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e1f6..7d09e9b820 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+ // Categorical feature column and split describing the rule feature value ==
+ // value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06941..4ae26fb95b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+ if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
+ // Accumulate tree_logits only if the leaf is non-root, but do so
+ // for bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+ // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
+ // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Only create one, if one does not exist already. Report status for all
+ // other exceptions. If one already exists, it unrefs the new one.
+ // An epsilon value of zero could cause perfoamance issues and is therefore,
+ // disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantiles thresholds and input data, this operation
+// converts the input features into the buckets (categorical values), depending
+// on which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+ // Iterating over all resources
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63949..12d9473776 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
# Description:
-# This directory contains common utilities used in boosted_trees.
+# This directory contains common quantile utilities used in boosted_trees.
package(
default_visibility = ["//tensorflow:internal"],
)
@@ -16,6 +16,7 @@ cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
+ "quantile_stream_resource.h",
"weighted_quantiles_buffer.h",
"weighted_quantiles_stream.h",
"weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+ // Whether boundaries are created. Initially boundaries are empty until
+ // set_boundaries are called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper-bound for the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index de9b69828e..639c3062cc 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -137,17 +137,16 @@ struct MatMulConvFunctor {
}
};
-// Shuffles a filter tensor from:
-// [<spatial_dims>, in, out]
-// to:
-// [out, in, <spatial_dims>]
+// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
+//
+// Note: Currently OIHW is the only supported destination format. Support for
+// OHWI format will be added in a follow-up change.
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
- // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
- // to speed up the shuffle operation.
+ // Merge the spatial dimensions together to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
merged_dims[0] = in.dimension(0); // spatial dimensions
for (int i = 1; i < NDIMS - 2; ++i) {
@@ -156,16 +155,30 @@ struct TransformFilter {
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported destination filter format: "
+ << ToString(dst_filter_format);
+ // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
+ // in the beginning.
+ Eigen::DSizes<IndexType, 3> shuffling_perm =
+ Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
- expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
- expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
- expanded_dims[i + 2] = in.dimension(i);
+ int out_index = 0;
+ for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
+ if (shuffling_perm[merged_dim] == 0) {
+ for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
+ expanded_dims[out_index++] = in.dimension(spatial_dim);
+ }
+ } else {
+ constexpr int kLastSpatialDim = NDIMS - 3;
+ expanded_dims[out_index++] =
+ in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+ }
}
- out.device(d) = in.reshape(merged_dims)
- .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
- .reshape(expanded_dims);
+ out.device(d) =
+ in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
}
};
@@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 {
const gtl::ArraySlice<int64>& input_dims, T* out);
};
-// Reverses the effect of TransformFilter above.
+// Transforms back filter from OIHW to HWOI format to reverse effect of
+// TransformFilter above.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 02e3655ad1..b819c6f910 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
namespace tensorflow {
@@ -28,6 +29,14 @@ namespace functor {
template <typename Device, typename T>
struct CuboidConvolution;
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
@@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> {
}
};
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor input_backward,
+ typename TTypes<T, 5>::ConstTensor filter,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward,
+ input_backward.dimension(3), // input_planes
+ input_backward.dimension(2), // input_rows
+ input_backward.dimension(1), // input_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor filter_backward,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward,
+ filter_backward.dimension(2), // kernel_planes
+ filter_backward.dimension(1), // kernel_rows
+ filter_backward.dimension(0), // kernel_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 63b1bcda43..9e86a16b66 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -1018,7 +1018,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d664a11e73..43bb5ea56c 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_out_backprop;
@@ -1090,7 +1091,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fc0a2f123f..507720c998 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@ limitations under the License.
namespace tensorflow {
+// Compute padding for the given spatial dimension.
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+ int dim) const {
+ return (padding == VALID)
+ ? 0
+ : std::max<int>(
+ 0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+ (filter_size(dim) - 1) * dilation(dim) +
+ 1 - input_size(dim)));
+}
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d53a..9551959463 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@ struct ConvBackpropDimensions {
// Input and output feature depth.
int64 in_depth, out_depth;
+
+ // Convenience access methods for spatial dimensions properties.
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+ int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+ int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+ int64 stride(int dim) const { return spatial_dims[dim].stride; }
+ int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+ // Compute padding for the given spatial dimension.
+ int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..bab91f5e86 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
+// order (planes, height, width, depth), constructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+ // Jump over remaining number of depth.
+ im_patch_data += depth * (width - filter_w);
+ }
+ // Jump over remaining number of (depth * width).
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data', image patches in storage order (planes, height, width,
+// depth) extracted from image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // This should be simply padded with zero.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+} // namespace
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use parallel tensor contractions if there is no batching.
+ //
+ // Compared to Conv2D code, this version is missing work size estimation. In
+ // benchmarks I didn't find a case when it's beneficial to run parallel
+ // contraction compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
+// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
+// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Shard 'batch' images (volumes) into 'shard_size' groups of images
+ // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+ // dividing the L3 cache size ('target_working_set_size') by the matmul size
+ // of an individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
+// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable and can't be used in
+// custom backprop filter kernel because of memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -445,7 +1054,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
@@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1221,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
- context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
@@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ef692418d6..717a9f40a9 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -680,9 +680,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
TensorShape({filter.dim_size(3), filter.dim_size(2),
filter.dim_size(0), filter.dim_size(1)}),
&transformed_filter));
-
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_output;
@@ -731,9 +731,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
- &algorithms));
+ OP_REQUIRES(
+ ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown("Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -823,7 +829,8 @@ namespace functor {
extern template struct MatMulConvFunctor<GPUDevice, T>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index a1eed4e68c..83df4dce38 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> {
// filter: [x, y, z, in, out]
// t_filter: [out, in, x, y, z]
functor::TransformFilter<GPUDevice, T, int, 5>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
Tensor transformed_output;
@@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
- stream->parent()),
- &algorithms));
+ OP_REQUIRES(ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown(
+ "Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
+
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -514,7 +521,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f277..21d135decd 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@ class ConvParameters {
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
+ auto* dnn_support = stream_exec->AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
// Skip this check for cuDNN 7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
+ auto version = dnn_support->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a5fa48f85e..46167db3a2 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
return tensor_index;
}
-// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
- Dimension<3> output_dims;
- output_dims[0] = input_dims[2];
- output_dims[1] = input_dims[1];
- output_dims[2] = input_dims[0];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
-
- Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[2];
- input_tensor_index[1] = output_tensor_index[1];
- input_tensor_index[2] = output_tensor_index[0];
-
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
-
- output[output_index] =
- maybe_conj<T, conjugate>::run(ldg(input + input_index));
- }
-}
-
-// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
+// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to
+// the given shuffle permutation in template parameters. Shuffle permutation
+// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0,
+// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1>
+// will populate output so that input[x][y][z] is equal to (*output)[y][z][x].
+//
+// Requires that nthreads is equal to the total number of elements in the input
+// tensor.
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
+ Dimension<3> input_dims, T* output) {
Dimension<3> output_dims;
- output_dims[0] = input_dims[0];
- output_dims[1] = input_dims[2];
- output_dims[2] = input_dims[1];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
+ output_dims[sp0] = input_dims[0];
+ output_dims[sp1] = input_dims[1];
+ output_dims[sp2] = input_dims[2];
+
+ // Iterate over output as opposed to iterating over input for better
+ // performance. Iterating over output will generate sequential writes and
+ // random reads that performs better compared to sequential reads and random
+ // writes.
+ CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[2];
- input_tensor_index[2] = output_tensor_index[1];
+ input_tensor_index[0] = output_tensor_index[sp0];
+ input_tensor_index[1] = output_tensor_index[sp1];
+ input_tensor_index[2] = output_tensor_index[sp2];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
@@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, int>::ConstTensor in,
typename TTypes<T, NDIMS, int>::Tensor out) {
Dimension<3> combined_dims;
@@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported output layout: " << ToString(dst_filter_format);
+
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
-// Converts Cudnn filter format back to TensorFlow filter format.
+// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
@@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
combined_dims[2] *= in.dimension(i);
}
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
@@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
- SwapDimension1And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, input, input_dims, output);
}
@@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
- SwapDimension0And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in, input_dims, out);
}
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 3a1ac73f64..87efdff789 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -628,6 +628,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -675,6 +689,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "model_dataset_op",
+ srcs = ["model_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
srcs = ["dataset_ops.cc"],
deps = [
@@ -708,6 +735,8 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
+ ":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index a25f78c6f1..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 221b5ad835..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ad2365b25b..0bb929b3ce 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,43 +17,101 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
-/* static */
-Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- std::unique_ptr<CapturedFunction>* out_function) {
- return Create(func, std::move(captured_inputs), true, out_function);
-}
+namespace {
+
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ int64 processing_time() {
+ tf_shared_lock l(mu_);
+ return processing_time_;
+ }
+
+ private:
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+ start_time_ns_);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordComputeStarted() override {}
+
+ void RecordComputeEnded() override {}
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void SetMemory(OpKernelContext* ctx) override {}
+
+ void SetOutput(int slot, const Tensor* tensor) override {}
+
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
/* static */
Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
- use_inter_op_parallelism));
- return Status::OK();
+ return CapturedFunction::Create(func, ctx, argument, true, out_function);
}
-/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
- OpInputList argument_inputs;
- TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
- std::vector<Tensor> arguments_t;
- arguments_t.reserve(argument_inputs.size());
- for (const Tensor& t : argument_inputs) {
- arguments_t.push_back(t);
- }
- return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+ std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+ *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+ use_inter_op_parallelism));
+ return Status::OK();
}
CapturedFunction::~CapturedFunction() {
@@ -358,7 +416,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
@@ -368,13 +427,13 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
done(s);
return;
}
- auto frame =
+ OwnedArgsCallFrame* frame =
new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
@@ -389,25 +448,40 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager* c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ if (ctx->model()) {
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
+ }
+ f_opts.stats_collector = stats_collector.get();
+
+ auto callback = std::bind(
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Model>& model, const string& prefix,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ if (model) {
+ model->AddProcessingTime(prefix, stats_collector->processing_time());
+ model->RecordStart(prefix, false /* stop_output */);
+ }
+ done(s);
+ if (model) {
+ model->RecordStop(prefix, false /* start_output */);
+ }
+ },
+ std::move(done), ctx->model(), prefix, std::move(stats_collector),
+ std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e44bc78b1c..a10376bf97 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -42,27 +42,19 @@ namespace data {
// context.
class CapturedFunction {
public:
- // Creates a new instance from a list of named attributes and captured inputs.
- //
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
std::unique_ptr<CapturedFunction>* out_function);
- // Creates a new instance from a list of named attributes and captured inputs.
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
//
// If `use_inter_op_parallelism` is false, the runtime may use an executor
// that is optimized for small functions.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction>* out_function);
-
- // Creates a new instance using a list of named attributes, fetching captured
- // inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
- const string& argument,
+ const string& argument, bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
~CapturedFunction();
@@ -104,7 +96,8 @@ class CapturedFunction {
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix);
// Returns the named list of function arguments.
const NameAttrList& func() { return func_; }
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e7ac368ae3..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -44,5 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 234856ea39..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -27,6 +27,16 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bf0aecaf3c..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace data {
@@ -37,14 +39,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
FunctionLibraryRuntime::Handle pred_handle;
OP_REQUIRES_OK(ctx,
ctx->function_library()->Instantiate(
@@ -61,9 +55,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Node* ret_node = pred_body->ret_nodes[0];
Node* ret_input_node;
OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
if (ret_input_node->def().op() == "_Arg") {
int32 index = -1;
@@ -146,7 +141,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params) {}
+ : DatasetIterator<FilterDatasetBase>(params),
+ filtered_elements_(0),
+ dropped_elements_(0) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -161,6 +162,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
// `input_impl_` and `f` are thread-safe. However, if multiple
// threads enter this method, outputs may be observed in a
// non-deterministic order.
+ auto stats_aggregator = ctx->stats_aggregator();
bool matched;
do {
{
@@ -183,8 +185,34 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ dropped_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::dropped_elements"),
+ static_cast<float>((dropped_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect
+ // aggregated number of dropped elements for all the pipelines,
+ // exploit tagged_context here.
+ stats_aggregator->IncrementCounter(
+ prefix_end_, "dropped_elements", static_cast<float>(1));
+ }
}
} while (!matched);
+ // TODO(shivaniagrawal): add ratio of dropped_elements and
+ // filtered_elements as a histogram.
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ filtered_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::filtered_elements"),
+ static_cast<float>((filtered_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+ // number of filtered elements for all the pipelines, exploit
+ // tagged_context here.
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+ static_cast<float>(1));
+ }
*end_of_sequence = false;
return Status::OK();
}
@@ -197,6 +225,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
else
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+ filtered_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+ dropped_elements_));
return Status::OK();
}
@@ -207,12 +239,19 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
input_impl_.reset();
else
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+ &filtered_elements_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+ &dropped_elements_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 filtered_elements_ GUARDED_BY(mu_);
+ int64 dropped_elements_ GUARDED_BY(mu_);
+ string prefix_end_;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index e3c45ef86c..2fada22a21 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
-
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ac5cc1b2c1..71a36314a0 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -145,44 +145,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index e4fa557598..8b417bb1c2 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -42,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- // Get captured inputs for the key, reduce, and window_size functions.
- OpInputList key_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
- &key_func_other_argument_inputs));
- std::vector<Tensor> key_func_other_arguments;
- key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
- for (const Tensor& t : key_func_other_argument_inputs) {
- key_func_other_arguments.push_back(t);
- }
- OpInputList reduce_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
- &reduce_func_other_argument_inputs));
- std::vector<Tensor> reduce_func_other_arguments;
- reduce_func_other_arguments.reserve(
- reduce_func_other_argument_inputs.size());
- for (const Tensor& t : reduce_func_other_argument_inputs) {
- reduce_func_other_arguments.push_back(t);
- }
- OpInputList window_size_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("window_size_func_other_arguments",
- &window_size_func_other_argument_inputs));
- std::vector<Tensor> window_size_func_other_arguments;
- window_size_func_other_arguments.reserve(
- window_size_func_other_argument_inputs.size());
- for (const Tensor& t : window_size_func_other_argument_inputs) {
- window_size_func_other_arguments.push_back(t);
- }
- // TODO(mrry): Refactor CapturedFunction to share the runtime
- // state between multiple functions?
std::unique_ptr<CapturedFunction> captured_key_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- key_func_, std::move(key_func_other_arguments),
- &captured_key_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
std::unique_ptr<CapturedFunction> captured_reduce_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(reduce_func_,
- std::move(reduce_func_other_arguments),
- &captured_reduce_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
std::unique_ptr<CapturedFunction> captured_window_size_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- window_size_func_, std::move(window_size_func_other_arguments),
- &captured_window_size_func));
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(window_size_func_, ctx,
+ "window_size_func_other_arguments",
+ &captured_window_size_func));
*output = new Dataset(
ctx, input, key_func_, reduce_func_, window_size_func_,
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 0768f46665..0aa802b874 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
const Tensor* cycle_length_t;
OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("block_length must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index fe6d705eab..c0bc507ec0 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
@@ -403,12 +366,12 @@ class IteratorStateVariant {
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
- bool Decode(const VariantTensorData& data) {
+ bool Decode(VariantTensorData data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
- *tensor_data = data;
+ std::swap(*tensor_data, data);
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 27c89b3661..2bbf4af664 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
@@ -39,7 +41,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +50,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +70,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -92,8 +86,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -190,7 +184,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -204,6 +199,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -218,12 +223,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
@@ -326,11 +333,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
+ mutex_lock l(mu_);
+ num_calls_--;
+ result->num_calls--;
cond_var_.notify_all();
}
@@ -365,7 +370,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -422,11 +428,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -475,26 +476,34 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
+ RecordStart(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
@@ -638,6 +647,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
@@ -661,7 +672,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index af301e2b42..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -38,18 +38,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index b87d61ee44..6657f2b2b3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -81,119 +81,167 @@ class MapDefunOp : public AsyncOpKernel {
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size;
- OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+ ComputeOptions* compute_opts = nullptr;
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- // Create a copy because every `Compute` may have different output shapes.
- auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
- auto* mu = new mutex;
-
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- }
-
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
-
- for (size_t i = 0; i < output_types().size(); ++i) {
- if (output_shapes_.at(i).IsFullyDefined()) {
- Tensor* out = nullptr;
- TensorShape output_shape;
- output_shapes_.at(i).AsTensorShape(&output_shape);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
- done);
- }
- }
-
- SetRunOptions(ctx, &opts_, false);
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes,
- std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
- mutex* mu, DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
- delete output_shapes;
- delete mu;
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
+ DoneCallback& done, const Status& status) {
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
- std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
- refcounted->Ref();
- }
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame = new MapFunctionCallFrame(
- *args, *arg_shapes, output_shapes, mu, output, this, i,
- static_cast<size_t>(batch_size));
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager;
- opts_.cancellation_manager = c_mgr;
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted, c_mgr](const Status& func_status) {
- delete call_frame;
- delete c_mgr;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
+ refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
std::vector<PartialTensorShape> output_shapes_;
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ const std::vector<Tensor> args;
+ const std::vector<TensorShape> arg_shapes;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(std::vector<Tensor> args,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(std::move(args)),
+ arg_shapes(std::move(arg_shapes)),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ int64 batch_size =
+ ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<Tensor> args;
+ std::vector<TensorShape> arg_shapes;
+ args.reserve(ctx->num_inputs());
+ arg_shapes.reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ arg_shapes.push_back(ctx->input(i).shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
+
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- std::vector<PartialTensorShape>* output_shapes,
- mutex* output_shapes_mutex, OpOutputList* output,
- OpKernel* kernel, size_t iter, size_t batch_size)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_shapes_(output_shapes),
- output_shapes_mutex_(output_shapes_mutex),
- output_(output),
- kernel_(kernel),
- iter_(iter),
- batch_size_(batch_size) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+ bool result =
+ val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
@@ -217,36 +265,34 @@ class MapDefunOp : public AsyncOpKernel {
index);
}
{ // Locking scope
- mutex_lock l(*output_shapes_mutex_);
- if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
return errors::InvalidArgument(
"Mismatch in function retval shape, ", val.shape(),
- ", and expected output shape,",
- output_shapes_->at(index).DebugString(), ".");
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
}
- if (!output_shapes_->at(index).IsFullyDefined()) {
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
// Given val, we have new information about the output shape at
// this index. Store the shape and allocate the output accordingly.
- output_shapes_->at(index) = val.shape();
+ compute_opts_->output_shapes.at(index) = val.shape();
Tensor* out = nullptr;
TensorShape actual_shape = val.shape();
- actual_shape.InsertDim(0, batch_size_);
- TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
}
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- std::vector<PartialTensorShape>* output_shapes_;
- mutex* output_shapes_mutex_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
- const size_t batch_size_;
};
};
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000000..9aa505f4f1
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ModelDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Model")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
+
+ ~Iterator() override {
+ // Signal the optimize thread to terminate it. We will then join that
+ // thread when we delete `this->optimize_thread_`.
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+ &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return input_impl_->GetNext(&ctx_with_model, out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
+ IteratorContext::Params params = ctx->params();
+ params.model = model_;
+ return params;
+ }
+
+ private:
+ Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!optimize_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ }
+ return Status::OK();
+ }
+
+ void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+ int64 last_optimization_ms = 0;
+ int64 optimization_period_ms = 10;
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ last_optimization_ms + optimization_period_ms >=
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
+ cond_var_.wait_for(
+ l, std::chrono::milliseconds(
+ last_optimization_ms + optimization_period_ms -
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
+ }
+ if (cancelled_) return;
+ }
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization
+ // until a threshold is reached.
+ if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms <<= 1;
+ } else {
+ optimization_period_ms = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ }
+
+ mutex mu_;
+ condition_variable cond_var_;
+ std::shared_ptr<model::Model> model_;
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+ ModelDatasetOp);
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
new file mode 100644
index 0000000000..5f143967d9
--- /dev/null
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -0,0 +1,633 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
+class MultiDeviceIterator : public ResourceBase {
+ public:
+ MultiDeviceIterator(const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const std::vector<string>& devices,
+ std::unique_ptr<FunctionLibraryDefinition> flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ FunctionLibraryRuntime* lib)
+ : output_types_(output_types),
+ output_shapes_(output_shapes),
+ devices_(devices),
+ flib_def_(std::move(flib_def)),
+ pflr_(std::move(pflr)),
+ lib_(lib) {
+ DCHECK(lib_ != nullptr);
+ }
+
+ string DebugString() override {
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
+ }
+
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
+ if (iterator) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, iterator->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
+ }
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
+ *incarnation_id = incarnation_id_;
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
+ return Status::OK();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
+ }
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
+ }
+
+ const DataTypeVector& output_types() const { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+ tf_shared_lock l(mu_);
+ return lib_def_;
+ }
+
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
+ private:
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (background_thread_finished_) {
+ return;
+ }
+
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (background_thread_finished_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
+ };
+
+ mutex mu_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<string> devices_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
+ std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
+};
+
+// Just creates a MultiDeviceIterator and returns it.
+class MultiDeviceIteratorHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+ }
+
+ // The resource is deleted from the resource manager only when it is private
+ // to kernel.
+ ~MultiDeviceIteratorHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MultiDeviceIterator>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<MultiDeviceIterator>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MultiDeviceIterator(
+ output_types_, output_shapes_, devices_,
+ std::move(flib_def), std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MultiDeviceIterator>()));
+ }
+
+ private:
+ // During the first Compute(), resource is either created or looked up using
+ // shared_name. In the latter case, the resource found should be verified if
+ // it is compatible with this op's configuration. The verification may fail in
+ // cases such as two graphs asking queues of the same shared name to have
+ // inconsistent capacities.
+ Status VerifyResource(MultiDeviceIterator* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+ string container_;
+ std::vector<string> devices_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
+ MultiDeviceIteratorHandleOp);
+
+// Calls init on the MultiDeviceIterator.
+class MultiDeviceIteratorInitOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
+ core::ScopedUnref unref(resource);
+
+ std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
+ int64 incarnation_id;
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
+ Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
+ tensor_incarnation_id.scalar<int64>()() = incarnation_id;
+ OP_REQUIRES_OK(ctx,
+ ctx->set_output("incarnation_id", tensor_incarnation_id));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
+ MultiDeviceIteratorInitOp);
+
+// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
+// TODO(rohanj): Implement using BackgroundWorker that Derek built?
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
+ public:
+ explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("multi_device_iterator_get_next_thread_",
+ SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor* tensor_shard_num;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
+ int32 shard_num = tensor_shard_num->scalar<int32>()();
+
+ const Tensor* tensor_incarnation_id;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
+ int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
+
+ MultiDeviceIterator* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ thread_pool_->Schedule(std::bind(
+ [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
+
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
+ },
+ std::move(done)));
+ }
+
+ private:
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
+ MultiDeviceIteratorGetNextFromShardOp);
+
+class MultiDeviceIteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorToStringHandleOp);
+
+class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_types_.empty() || output_shapes_.empty() ||
+ output_types_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx,
+ resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
+ core::ScopedUnref unref_iterator(resource);
+ if (!output_types_.empty()) {
+ OP_REQUIRES_OK(ctx,
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ }
+ if (!output_shapes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
+ resource->output_shapes()));
+ }
+
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorFromStringHandleOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index b372d31a93..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -108,11 +108,8 @@ class OptionalFromValueOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
OpInputList components_input;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
- std::vector<Tensor> components;
- components.reserve(components_input.size());
- for (const Tensor& component_t : components_input) {
- components.push_back(component_t);
- }
+ std::vector<Tensor> components(components_input.begin(),
+ components_input.end());
OP_REQUIRES_OK(
ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
@@ -216,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
@@ -231,10 +236,9 @@ static Status OptionalDeviceCopy(
return Status::OK();
}
-#define REGISTER_OPTIONAL_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
- OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, OptionalDeviceCopy)
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index fd0e6c4cd0..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 640f1565b7..2e6e0465f7 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <atomic>
#include <deque>
#include <utility>
@@ -44,14 +45,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -83,8 +76,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -252,6 +245,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -351,11 +345,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ RecordStart(ctx);
}
}
return errors::Cancelled(
@@ -484,10 +480,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -583,10 +579,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -601,7 +597,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -622,10 +619,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -651,9 +649,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
// 1a. Get new input tensors or use the exiting ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -665,7 +661,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -715,7 +713,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -764,7 +764,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
@@ -1093,9 +1095,6 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -1111,7 +1110,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
@@ -1119,16 +1118,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
errors::InvalidArgument(
"num_parallel_calls must less than or equal to cycle_length."));
- // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, interleave_func_,
std::move(captured_func), cycle_length, block_length,
@@ -1221,6 +1214,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1241,6 +1235,16 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1256,7 +1260,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1265,9 +1271,11 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
+ RecordStop(ctx);
result->notification.WaitForNotification();
+ RecordStart(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1391,6 +1399,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1408,56 +1418,66 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
- {
- if (end_of_input) {
- current_elements_[cycle_index].reset();
- }
- mutex_lock l(mu_);
- element_in_use_[cycle_index] = false;
- num_calls_--;
- if (end_of_input) {
- args_list_[cycle_index].clear();
- num_open_--;
- }
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
}
cond_var_.notify_all();
}
- int64 MaxInvocationResults() {
- return dataset()->cycle_length_ * dataset()->block_length_;
- }
-
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ return element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >=
+ dataset()->cycle_length_ * dataset()->block_length_;
+ };
while (true) {
- {
- mutex_lock l(mu_);
- // Wait until this thread is cancelled, the end of input has been
- // reached, or the cycle element at the `cycle_index_` position is
- // not in use and there is space in the `invocation_results_` queue.
- while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
- (element_in_use_[cycle_index_] ||
- num_calls_ >= dataset()->num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- cond_var_.wait(l);
- }
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+ RecordStop(ctx.get());
+ cond_var_.wait(l);
+ RecordStart(ctx.get());
+ }
- if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
- return;
- }
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
- while (!element_in_use_[cycle_index_] &&
- (!end_of_input_ || num_open_ > 0) &&
- num_calls_ < dataset()->num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- if (!current_elements_[cycle_index_]) {
- // Try to create a new iterator from the next input element.
- Status status = input_impl_->GetNext(
- ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
@@ -1466,39 +1486,25 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
result->notification.Notify();
break;
}
- if (!end_of_input_) {
- Status status = MakeIteratorFromInputElement(
- ctx.get(), args_list_[cycle_index_], cycle_index_,
- dataset()->captured_func_.get(), prefix(),
- &current_elements_[cycle_index_]);
- if (!status.ok()) {
- invocation_results_.emplace_back(new InvocationResult());
- std::shared_ptr<InvocationResult>& result =
- invocation_results_.back();
- result->status.Update(status);
- result->notification.Notify();
- break;
- }
- ++num_open_;
- }
+ ++num_open_;
}
- if (current_elements_[cycle_index_]) {
- // Pre-allocate invocation results for outputs to be fetched
- // and then fetch the outputs asynchronously.
- std::vector<std::shared_ptr<InvocationResult>> results;
- results.reserve(dataset()->block_length_);
- for (int i = 0; i < dataset()->block_length_; ++i) {
- invocation_results_.emplace_back(new InvocationResult());
- results.push_back(invocation_results_.back());
- }
- num_calls_++;
- element_in_use_[cycle_index_] = true;
- thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
- ctx, cycle_index_,
- std::move(results)));
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
}
- cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
}
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
@@ -1601,6 +1607,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// and there are elements left to be fetched.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
+
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a0cb179eb8..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -44,25 +44,17 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
@@ -97,31 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- ParallelMapIteratorFunction map_func;
- if (use_inter_op_parallelism_) {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
- } else {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(
- [this, ctx, result](std::vector<Tensor>& input_element,
- StatusCallback& done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- },
- std::move(input_element), std::move(done)));
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
};
}
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(init_func), std::move(map_func), num_parallel_calls_);
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4ae742aaaf..ee20249bfe 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,11 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include <atomic>
#include <deque>
#include <functional>
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
namespace tensorflow {
namespace data {
namespace {
@@ -37,11 +41,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
num_parallel_calls_(num_parallel_calls) {}
~ParallelMapIterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
@@ -53,6 +52,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
+ 1 /* min */, port::NumSchedulableCPUs() /* max */,
+ &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -68,13 +78,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
+ RecordStop(ctx);
result->notification.WaitForNotification();
+ RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -87,9 +101,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("invocation_results.size"),
- invocation_results_.size()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
@@ -176,9 +189,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
{
mutex_lock l(mu_);
num_calls_--;
+ cond_var_.notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -193,9 +206,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
@@ -205,8 +217,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -226,27 +236,33 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ return num_calls_ >= num_parallel_calls ||
+ invocation_results_.size() >= num_parallel_calls;
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
+ while (!busy()) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
@@ -295,7 +311,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
@@ -304,6 +319,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 0cf5db017b..c28c06da62 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"Expected len(dense_defaults) == len(dense_keys) but got: ",
dense_default_tensors.size(), " vs. ", dense_keys_.size()));
- std::vector<Tensor> dense_defaults;
- dense_defaults.reserve(dense_default_tensors.size());
- for (const Tensor& dense_default_t : dense_default_tensors) {
- dense_defaults.push_back(dense_default_t);
- }
+ std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+ dense_default_tensors.end());
for (int d = 0; d < dense_keys_.size(); ++d) {
const Tensor& def_value = dense_defaults[d];
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index 533d0bd5d2..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -26,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -37,7 +44,11 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index ad7d5eb3ff..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -102,16 +103,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
- auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
if (cancelled_) {
@@ -133,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -216,6 +227,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
strings::StrCat(prefix_end_, "::buffer_utilization"),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
@@ -239,10 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -251,8 +268,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -260,7 +278,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
@@ -277,8 +297,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
bool end_of_sequence;
BufferElement buffer_element;
- buffer_element.status =
- input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ buffer_element.status = input_impl_->GetNext(
+ ctx.get(), &buffer_element.value, &end_of_sequence);
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(mu_);
prefetch_thread_finished_ = true;
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 6e515d6cc8..dbe31f37b8 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
OpInputList initial_state_inputs;
OP_REQUIRES_OK(ctx,
ctx->input_list("initial_state", &initial_state_inputs));
- std::vector<Tensor> initial_state;
- initial_state.reserve(initial_state_inputs.size());
- for (const Tensor& t : initial_state_inputs) {
- initial_state.push_back(t);
- }
-
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
+ std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+ initial_state_inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index e1cefd23d8..ca4ea25b89 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -33,11 +33,7 @@ class TensorDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
- std::vector<Tensor> components;
- components.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- components.push_back(t);
- }
+ std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
}
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3975086841..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
: DatasetBase(DatasetContext(ctx)),
+ input_(input),
window_size_(window_size),
- input_(input) {
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
+ window_stride_, drop_remainder_, ")::Dataset");
}
protected:
@@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
}
}
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
+ }
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
+ }
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(result), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a74b..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
@@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel {
const int32 abs_height = abs(height);
// there may be padding bytes when the width is not a multiple of 4 bytes
- // 8 * channels == bits per pixel
- const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+ const int row_size = (channels_ * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847c16..6bfb5bd5bc 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459194..76afd6f18c 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723f..e7882acc80 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index 27918b410b..8edf7d4a2c 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -59,12 +59,12 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > >,
+ const Kernel>>>>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > >,
+ const OutputBackward>>>>,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
@@ -75,7 +75,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
+ const OutputBackward>>,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
@@ -83,7 +83,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > > > > >::type
+ const Kernel>>>>>>>::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
@@ -94,12 +94,12 @@ CuboidConvolutionBackwardInput(
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ internal::traits<Kernel>::Layout, TensorIndex>>
kern(kernel);
const TensorRef<
const Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -239,8 +239,8 @@ CuboidConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // kernelPlanes, kernelRows and kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, kernelPlanes, kernelRows and kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -323,35 +323,69 @@ CuboidConvolutionBackwardInput(
*/
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
- internal::traits<OutputBackward>::Layout == ColMajor,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > > > >,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > >,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward> > > >::type
+ internal::traits<Input>::Layout == ColMajor,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>>>>>,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>>>>>>::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -362,11 +396,11 @@ CuboidConvolutionBackwardKernel(
typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ internal::traits<Input>::Layout, TensorIndex>>
in(input);
TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -380,6 +414,13 @@ CuboidConvolutionBackwardKernel(
internal::traits<OutputBackward>::NumDimensions,
YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // We do not support higher dimensional backward convolutions, or convolutions
+ // without batch dimension.
+ // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+ // dimension in eigen_backward_cuboid_convolutions_test.cc.
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+
const TensorIndex inputPlanes =
isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
const TensorIndex inputRows =
@@ -401,6 +442,10 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
+ // Number of batches in the input tensor.
+ const TensorIndex batch =
+ isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
// TODO(ezhulenev): Add support for inflated strides. Without inflated strides
// effective kernel planes/rows/cols are always the same as the kernel itself
// (see eigen_spatial_convolutions for details).
@@ -408,6 +453,7 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelRowsEff = kernelRows;
const TensorIndex kernelColsEff = kernelCols;
+ // Compute forward padding from input and output_backward dimensions.
const TensorIndex padPlanes = numext::maxi<Index>(
0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
const TensorIndex padRows = numext::maxi<Index>(
@@ -416,92 +462,147 @@ CuboidConvolutionBackwardKernel(
0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
const TensorIndex padding_top_z = padPlanes / 2;
- const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
const TensorIndex padding_top = padRows / 2;
- const TensorIndex padding_bottom = padRows - padding_top;
const TensorIndex padding_left = padCols / 2;
- const TensorIndex padding_right = padCols - padding_left;
- // Reshaped output_backward before contraction.
- DSizes<TensorIndex, 2> output_dims;
+ // Compute paddings for output_backward before extracting patches.
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+ const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+ const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+ const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+ const auto padded_out_rows = inputRows + kernelRows - 1;
+ const auto padded_out_cols = inputCols + kernelCols - 1;
+ const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+ const auto top_pad_rows = kernelRows - 1 - padding_top;
+ const auto left_pad_cols = kernelCols - 1 - padding_left;
+ const auto bottom_pad_planes =
+ padded_out_planes - expanded_out_planes - top_pad_planes;
+ const auto bottom_pad_rows =
+ padded_out_rows - expanded_out_rows - top_pad_rows;
+ const auto right_pad_cols =
+ padded_out_cols - expanded_out_cols - left_pad_cols;
+
+ // Reorder output_backward dimensions.
+ array<TensorIndex, 5> output_backward_shuffle;
if (isColMajor) {
- output_dims[0] = kernelFilters;
- output_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- output_dims[1] *= out.dimension(i);
- }
+ // From: [out_depth, out_planes, out_rows, out_cols, batch]
+ // To: [batch, out_planes, out_rows, out_cols, out_depth]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
} else {
- output_dims[1] = kernelFilters;
- output_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- output_dims[0] *= out.dimension(i);
- }
+ // From: [batch, out_cols, out_rows, out_planes, out_depth]
+ // To: [out_depth, out_cols, out_rows, out_planes, batch]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
}
- // Reshaped extract_volume_patches(in)
- DSizes<TensorIndex, 2> pre_contract_dims;
+ // Reorder input dimensions.
+ array<TensorIndex, 5> input_shuffle;
if (isColMajor) {
- pre_contract_dims[0] =
- kernelChannels * kernelPlanes * kernelRows * kernelCols;
- pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[1] *= in.dimension(i);
- }
- eigen_assert(output_dims[1] == pre_contract_dims[1]);
+ // From: [in_depth, in_planes, in_rows, in_cols, batch]
+ // To: [in_depth, batch, in_planes, in_rows, in_cols]
+ input_shuffle = {0, 4, 1, 2, 3};
} else {
- pre_contract_dims[1] =
- kernelCols * kernelRows * kernelPlanes * kernelChannels;
- pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= in.dimension(i);
- }
- eigen_assert(output_dims[0] == pre_contract_dims[0]);
+ // From: [batch, in_cols, in_rows, in_planes, in_depth]
+ // To: [in_cols, in_rows, in_planes, batch, in_depth]
+ input_shuffle = {1, 2, 3, 0, 4};
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
+ // Input is playing the role of a "kernel" in this convolution.
+ DSizes<TensorIndex, 2> input_dims;
+ if (isColMajor) {
+ input_dims[0] = kernelChannels;
+ input_dims[1] = batch * inputPlanes * inputRows * inputCols;
+ } else {
+ input_dims[1] = kernelChannels;
+ input_dims[0] = inputCols * inputRows * inputPlanes * batch;
+ }
+ // Molds the output of the patch extraction result into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+ pre_contract_dims[1] =
+ kernelPlanes * kernelRows * kernelCols * kernelFilters;
+ } else {
+ pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+ pre_contract_dims[0] =
+ kernelFilters * kernelCols * kernelRows * kernelPlanes;
+ }
+
+ // We will contract along the collapsed dimension that contains the
+ // batch, inputPlanes, inputRows and inputCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- DSizes<TensorIndex, 5> kernel_dims;
+ // Dimensions after contraction.
+ DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelRows;
- kernel_dims[4] = kernelCols;
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = kernelPlanes;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelCols;
+ post_contract_dims[4] = kernelFilters;
} else {
- kernel_dims[4] = kernelFilters;
- kernel_dims[3] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[1] = kernelRows;
- kernel_dims[0] = kernelCols;
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = kernelCols;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelPlanes;
+ post_contract_dims[4] = kernelChannels;
}
- return choose(
- Cond<internal::traits<Input>::Layout == ColMajor>(),
- output_backward.reshape(output_dims)
- .contract(input
+ // Reorder output of contraction to valid filter shape.
+ array<TensorIndex, 5> kernel_shuffle;
+ if (isColMajor) {
+ // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+ // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+ kernel_shuffle = {4, 0, 1, 2, 3};
+ } else {
+ // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+ // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+ kernel_shuffle = {1, 2, 3, 4, 0};
+ }
+
+ // Reverse kernel backprop dimensions.
+ array<TensorIndex, 5> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse = {false, false, true, true, true};
+ } else {
+ kernel_reverse = {true, true, true, false, false};
+ }
+
+ // Create convolution input (aka source of patches) from output backward
+ // tensor by shuffling dimensions.
+ const auto the_input =
+ output_backward.shuffle(output_backward_shuffle).eval();
+
+ // Create convolution kernel (aka filter) from input by shuffling and
+ // reshaping.
+ const auto the_kernel =
+ input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ the_kernel.contract(
+ the_input
.extract_volume_patches(
- kernelPlanes, kernelRows, kernelCols, stridePlanes,
- strideRows, strideCols, 1, 1, 1, padding_top_z,
- padding_bottom_z, padding_top, padding_bottom,
- padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
- contract_dims)
- .reshape(kernel_dims),
- input
- .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
- stridePlanes, strideRows, strideCols, 1, 1, 1,
- padding_top_z, padding_bottom_z, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
- .contract(output_backward.reshape(output_dims), contract_dims)
- .reshape(kernel_dims));
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols,
+ top_pad_planes, bottom_pad_planes, top_pad_rows,
+ bottom_pad_rows, left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims),
+ contract_dims),
+ the_input
+ .extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, top_pad_planes,
+ bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+ left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims)
+ .contract(the_kernel, contract_dims))
+ .reshape(post_contract_dims)
+ .shuffle(kernel_shuffle)
+ .reverse(kernel_reverse);
}
} // end namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 8d06107553..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9659..673ec1458b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ /*num_batches*/ 1);
Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
patch_cols);
- Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
- output_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, /*num_batches*/ 1);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(id, i, j, k) *
- output_backward(od, output_i, output_j, output_k);
+ expected += input(id, i, j, k, /*batch*/ 0) *
+ output_backward(od, output_i, output_j,
+ output_k, /*batch*/ 0);
}
}
}
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
- input_depth);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+ input_planes, input_depth);
Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
input_depth, output_depth);
- Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
- output_planes, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(k, j, i, id) *
- output_backward(output_k, output_j, output_i, od);
+ expected += input(/*batch*/ 0, k, j, i, id) *
+ output_backward(/*batch*/ 0, output_k, output_j,
+ output_i, od);
}
}
}
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f9123d..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1412 @@ limitations under the License.
namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
+// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are continuous in memory (innermost dimension assuming
+// col major layout)
+//
+// With this dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for navigating through the single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
+ // Strides for the output tensor.
+ // IMPORTANT: These strides are used to locate an element in a patch at a
+ // depth zero (channel), which is not quite the same as "traditional"
+ // stride.
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If true, turns off some optimizations for loading packets since the image
+ // patches are "non-standard" such as there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+ return m_impl;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>;
+
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol =
+ (m_patch_col_inflate_strides == 1)
+ ? inputCol
+ : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow =
+ (m_patch_row_inflate_strides == 1)
+ ? inputRow
+ : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+ const Index origInputPlane =
+ (m_patch_plane_inflate_strides == 1)
+ ? inputPlane
+ : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+ if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+ origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+ origInputPlane >= m_inputPlanes ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+ (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+ origInputRow * m_rowInputStride +
+ origInputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+ inputRow >= m_inputRows || inputPlane < 0 ||
+ inputPlane >= m_inputPlanes) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+ return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ } else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
+ const Index patchOffsets[2] = {
+ patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+ patchOffsets[1] / m_fastColStride};
+ eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+ const Index inputCols[2] = {colIndex + colOffsets[0],
+ colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {
+ (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+ (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0],
+ rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] == inputRows[1]) {
+ const Index planeOffsets[2] = {
+ patchOffsets[0] - colOffsets[0] * m_colStride -
+ rowOffsets[0] * m_rowStride,
+ patchOffsets[1] - colOffsets[1] * m_colStride -
+ rowOffsets[1] * m_rowStride};
+ eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+ const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+ planeIndex + planeOffsets[1]};
+
+ if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex =
+ depth + inputPlanes[0] * m_planeInputStride +
+ inputRows[0] * m_rowInputStride +
+ inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ }
+
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+ inputCol >= m_inputCols || inputRow >= m_inputRows ||
+ inputPlane >= m_inputPlanes) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+ Index colIndex, Index otherIndex) const {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX
+ typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] =
+ loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ // Precompute the indices (plane, row, col, other) of the first element of
+ // the given patch index, within the output tensor of the TensorVolumePatchOp.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+ Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+ Index& otherIndex) const {
+ const size_t NumInputDims = array_size<
+ typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+ // Check if patchIndex might contain batch and other dimensions.
+ otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+ // Compute index of the patch within the batch (and other dimensions).
+ const Index patch3DIndex = (NumInputDims == 4)
+ ? patchIndex
+ : (patchIndex - otherIndex * m_num_patches);
+
+ otherIndex *= m_patchInputStride;
+
+ colIndex = patch3DIndex / m_fastOutputPlanesRows;
+ rowIndex =
+ (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+ planeIndex =
+ patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
+
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+ }
+
+ Index m_patch_depth; // number of channels in the patch
+ Index m_patch_planes; // number of planes in the patch
+ Index m_patch_rows; // number of rows in the patch
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract
+
+ // Strides for navigating through the single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not the part of the stride).
+ Index m_rowStride;
+ Index m_colStride;
+ Index m_patchStride;
+ Index m_otherStride;
+
+ Index m_planeInputStride; // Plane stride in the input tensor
+ Index m_rowInputStride; // Row stride in the input tensor
+ Index m_colInputStride; // Col stride in the input tensor
+ Index m_patchInputStride; // Patch stride in the input tensor
+ Index m_otherInputStride;
+
+ Index m_inputDepth; // Depth of the input tensor
+ Index m_inputPlanes; // Number of planes in the input tensor
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputPlanes; // Number of output planes
+ Index m_outputRows; // Number of output rows
+ Index m_outputCols; // Number of output cols
+ Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
+
+ Index m_plane_strides; // User specified plane stride
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ // User specified plane/row/col atrous convolution strides.
+ Index m_in_plane_strides;
+ Index m_in_row_strides;
+ Index m_in_col_strides;
+
+ // User specified plane/row/col inflation strides in the image patch.
+ Index m_patch_plane_inflate_strides;
+ Index m_patch_row_inflate_strides;
+ Index m_patch_col_inflate_strides;
+
+ Index m_planePaddingTop; // Plane padding
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
+ internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ internal::TensorIntDivisor<Index> m_fastRowStride;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth
+ internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastOutputCols;
+ internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ ParentMapper;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper),
+ m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper),
+ m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+ Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+ Index j) const {
+ return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+ j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+ loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
+ // plane and depth index respectively that fits into the peeled_k elements
+ // starting at m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input plane
+ // stride. Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const {
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const {
+ return m_base_mapper.m_patch_rows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const {
+ return m_base_mapper.m_patch_cols;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+ const Index p = m_planeIndex + plane;
+ return p < 0 || p >= m_base_mapper.m_inputPlanes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+ const Index col) const {
+ const Index p = m_planeIndex + plane;
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return p * m_base_mapper.m_planeInputStride +
+ r * m_base_mapper.m_rowInputStride +
+ c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index planeOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ const Index planeOffset = patchOffset -
+ colOffset * m_base_mapper.m_colStride -
+ rowOffset * m_base_mapper.m_rowStride;
+ return planeOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ return rowOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ return m_depth_offset % patchDepth();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+ getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper;
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_planeIndex;
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
+};
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if ((packet_size % 4) == 0 && !non_standard_patches) {
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+ // Packet can span multiple planes, rows or columns, so we have to go
+ // though the slower "standard" path.
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const int packet_size = 2;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ if (!rhs.nonStandardPatches()) {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
/** CuboidConvolution
* \ingroup CXX11_NeuralNetworks_Module
*
@@ -98,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
// Spatial size of the kernel.
- const TensorIndex kernelDepth =
+ const TensorIndex kernelPlanes =
isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
const TensorIndex kernelRows =
isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
@@ -118,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
const TensorIndex inputCols =
isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
- TensorIndex out_depth;
+ TensorIndex out_planes;
TensorIndex out_height;
TensorIndex out_width;
switch (padding_type) {
case PADDING_VALID:
- out_depth = Eigen::divup(inputPlanes - kernelDepth + 1,
- static_cast<TensorIndex>(stridePlanes));
+ out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
+ static_cast<TensorIndex>(stridePlanes));
out_height = Eigen::divup(inputRows - kernelRows + 1,
static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols - kernelCols + 1,
static_cast<TensorIndex>(strideCols));
break;
case PADDING_SAME:
- out_depth =
+ out_planes =
Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
out_height =
Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
break;
default:
- out_depth = 0;
+ out_planes = 0;
out_height = 0;
out_width = 0;
eigen_assert(false && "unexpected padding");
@@ -147,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
} else {
- kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelFilters;
}
@@ -160,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
pre_contract_dims[0] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[1] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_planes * out_height * out_width;
for (int i = 4; i < NumDims; ++i) {
pre_contract_dims[1] *= in.dimension(i);
}
} else {
pre_contract_dims[1] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[0] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_planes * out_height * out_width;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= in.dimension(i);
}
@@ -187,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
post_contract_dims[0] = kernelFilters;
- post_contract_dims[1] = out_depth;
+ post_contract_dims[1] = out_planes;
post_contract_dims[2] = out_height;
post_contract_dims[3] = out_width;
for (int i = 4; i < NumDims; ++i) {
@@ -195,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
}
} else {
post_contract_dims[NumDims - 1] = kernelFilters;
- post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 2] = out_planes;
post_contract_dims[NumDims - 3] = out_height;
post_contract_dims[NumDims - 4] = out_width;
for (int i = 0; i < NumDims - 4; ++i) {
@@ -208,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
kernel.reshape(kernel_dims)
.contract(input
.extract_volume_patches(
- kernelDepth, kernelRows, kernelCols, stridePlanes,
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
strideRows, strideCols, padding_type)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
input
- .extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
stridePlanes, strideRows, strideCols,
padding_type)
.reshape(pre_contract_dims)
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a4dff4b91c..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -22,8 +22,36 @@ namespace Eigen {
namespace internal {
-// TODO: Consolidate this part of the code with the image patch extraction code
-// since they are both very similar.
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract image patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelRows * kernelCols;
+// 1: out_height * out_width; * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are continuous in memory (innermost dimension assuming
+// col major layout)
+//
+// With this dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+// TODO(ezhulenev): Consolidate this part of the code with the image patch
+// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
@@ -77,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -111,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -126,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -148,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -238,6 +281,8 @@ class TensorContractionInputMapper<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>;
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
Index colIndex, Index otherIndex) const {
@@ -250,6 +295,7 @@ class TensorContractionInputMapper<
(m_patch_col_inflate_strides == 1)
? inputCol
: ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
const Index rowOffset = patchOffset - colOffset * m_colStride;
const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
const Index origInputRow =
@@ -268,6 +314,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
Index colIndex,
@@ -276,10 +324,9 @@ class TensorContractionInputMapper<
// Find the offset of the element wrt the location of the first element.
const Index patchOffset = patchId / m_fastDimZero;
-
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
inputRow >= m_inputRows) {
@@ -291,6 +338,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
Index colIndex,
@@ -318,12 +367,14 @@ class TensorContractionInputMapper<
if ((patchDepth() % packetSize) == 0) {
return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
} else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
const Index patchOffsets[2] = {
patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
patchOffsets[1] / m_fastColStride};
-
const Index inputCols[2] = {colIndex + colOffsets[0],
colIndex + colOffsets[1]};
if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
@@ -371,8 +422,8 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
inputRow >= m_inputRows) {
@@ -401,7 +452,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
const Index patch2DIndex = (NumInputDims == 3)
@@ -414,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -549,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
+ // index respectively that fits into the peeled_k elements starting at
+ // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input row stride.
+ // Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -563,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -603,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -617,12 +730,44 @@ class TensorContractionSubMapper<
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
- Index m_rowIndex; // precomputed row index corresponding to the col offset
- Index m_colIndex; // precomputed col index corresponding to the col offset
- Index
- m_otherIndex; // precomputed other index corresponding to the col offset
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
};
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted image patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,
typename nocontract_t, typename contract_t, int packet_size,
@@ -649,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -660,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -675,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows, if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -709,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -738,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -767,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -824,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -835,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -851,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -885,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -916,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple rows or columns, so we have to go
+ // though the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -949,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -968,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1006,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1016,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1045,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index a3d795813d..80ab745bfe 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -43,6 +43,7 @@ struct CustomTensorEvaluator {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = NumDims == 6,
RawAccess = false
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported as of now, but the comments hint how to edit the code
+when rates are to be added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+ // Rates are disabled as of now due to Eigen's definitions of
+ // `extract_volume_patch` functions; none of them accept rates
+ // as its argument and rates are fixed to (1, 1, 1, 1, 1). A
+ // workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa40304e..f2e0b2558f 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
+tf_ops_fuzz_target_lib("decode_compressed")
+
tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000000..0a56f4b63f
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+ void BuildGraph(const Scope& scope) override {
+ auto input =
+ tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+ auto d1 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d1"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType(""));
+ auto d2 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d2"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+ auto d3 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d3"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+ Scope grouper =
+ scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+ d1.output.op(), d2.output.op(), d3.output.op()});
+ (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+ }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+} // namespace fuzzing
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index cd2873bdca..7710cf93d6 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddc..374a05850e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index bca1cff41c..2088c13586 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@ static Status TensorListDeviceCopy(
return Status::OK();
}
-#define REGISTER_LIST_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+ TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) {
return Status::OK();
}
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
- TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
@@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<CPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index c591226b76..a00bf700ca 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<GPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<GPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..1ded012f3c 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ const char* valid_output_streams_[6] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index cc4b6941b9..62aa7d5c29 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = true,
RawAccess = false
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000000..a055351337
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks. //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default Tensorflow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks configure OpenMP environment variables:
+// export KMP_BLOCKTIME=0
+// export OMP_NUM_THREADS=${num_threads}
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+ Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+ : input_batches(n),
+ input_height(h),
+ input_width(w),
+ input_depth(c),
+ filter_count(fc),
+ filter_height(fh),
+ filter_width(fw) {}
+
+ int input_batches;
+ int input_height;
+ int input_width;
+ int input_depth;
+ int filter_count;
+ int filter_height;
+ int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+ Tensor tensor(DT_FLOAT, TensorShape(shape));
+ tensor.flat<float>() = tensor.flat<float>().setRandom();
+ return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+ MklDnnShape non_mkl_shape;
+ non_mkl_shape.SetMklTensor(false);
+
+ auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+ Tensor tensor(DT_UINT8, {size});
+
+ non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+ size * sizeof(uint8));
+ return tensor;
+}
+#endif
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+ .Input(input)
+ .Input(filter)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+ .Input(input)
+ .Input(filter)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+ "_MklConv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+ Graph* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+ "_MklConv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+#endif
+
+// Macro arguments names: --------------------------------------------------- //
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+ BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
+
+// Flops computation in these benchmarks are the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (C); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788448..82dfece4a2 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 7bb403290d..fc1c9003aa 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -127,12 +127,12 @@ class PartitionedCallOp : public AsyncOpKernel {
optimization_options.graph = &graph;
optimization_options.flib_def = overlay_lib;
optimization_options.device_set = &device_set;
- Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(
ctx,
OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
done);
+ Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
OP_REQUIRES_OK_ASYNC(
ctx,
@@ -210,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
- ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
}
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92f94..272aa3b4f5 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@ class QueueBase : public QueueInterface {
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d404259b..97ddc852f7 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+ const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e37232539f..04a53697c0 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
- // Verify that minval < maxval
+ // Allocate output, and exit early if possible
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ if (output->NumElements() == 0) return;
+
+ // Verify that minval < maxval. This check intentionally happens after the
+ // early exit for empty output. Zero impossible things are fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
- Tensor* output;
- OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2ac76..bb8254eaac 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f4bf..8bfa44b2d0 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index ebcfb673d1..26705a8d34 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -79,7 +79,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
+ const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a9c6..cded417986 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
}
template <>
@@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+ if (!context->status().ok()) return;
const int input_dims = input.dims();
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e0194605ce..2f8aede427 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel {
if (dtype_ == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000000..dc627ac77a
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+ explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+ explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000000..f075bf0fa2
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+ // Searches for values in sorted_inputs and returns the greatest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+ // Searches for values in sorted_inputs and returns the lowest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+} // namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000000..263b5bf298
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+} // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ UpperBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ LowerBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590ae61..30cb1e0a7f 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
- GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
static void ExpectHasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) {
Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
EXPECT_FALSE(s.ok());
ExpectHasError(
- s,
- "No unary variant shape function found for Variant type_name: "
- "NO KNOWN SHAPE");
+ s, strings::StrCat(
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
}
{
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 393818730b..a4a59dbcbc 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()(
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c532c9..11db72bfa3 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel {
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
- const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const Tensor& split_dim_tensor = context->input(0);
+ OP_REQUIRES(
+ context, split_dim_tensor.shape().dims() == 0,
+ errors::InvalidArgument("split_dim must be a scalar but has rank ",
+ split_dim_tensor.shape().dims()));
+ const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f61fd..add4afafc9 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- string key;
if (ctx->input_dtype(0) == DT_RESOURCE) {
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
- key = resource.name();
+ return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
} else {
Tensor Tstack_handle = ctx->mutable_input(0, false);
if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- key = strings::StrCat(container, stack_name);
- }
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) {
- return errors::Internal("No resource manager.");
- }
- auto* step_container = ctx->step_container();
- if (step_container == nullptr) {
- return errors::Internal("No step container.");
+ string key = strings::StrCat(container, stack_name);
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No resource manager.");
+ }
+ auto* step_container = ctx->step_container();
+ if (step_container == nullptr) {
+ return errors::Internal("No step container.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+ return Status::OK();
}
- TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
- return Status::OK();
}
std::atomic<int64> Stack::stack_counter{0};
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 7b537fef5b..f0575de4d9 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -306,6 +306,7 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index a6829b29d9..435a7abdca 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
class StringLengthOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
- for (int n = 0; n < src.size(); ++n) {
- dst(n) = src(n).size();
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
}
}
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..3a9803a052
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace {
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+} // namespace
+
+namespace tensorflow {
+
+// Sets unit value based on str.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid encoding \"", str, "\": Should be one of: BYTE"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..390cf57702
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..07f1d6e767 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -64,26 +68,28 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +148,16 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+ context,
+ FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +200,18 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ OP_REQUIRES(
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (", i,
+ ", ", j, ")"));
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +223,16 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+ // This adjusts the requested position. Note it does not perform any bound
+ // checks.
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000000..2e07050260
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor position(DT_INT32, TensorShape({}));
+ position.flat<int32>().setConstant(pos);
+ Tensor length(DT_INT32, TensorShape({}));
+ length.flat<int32>().setConstant(len);
+
+ TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, position))
+ .Input(test::graph::Constant(g, length))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_Substr(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
+ 256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index e8dc4fad21..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 2ec2651c04..a97a71b344 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -290,7 +291,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
} else {
container = "_tensor_arrays";
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
if (StringPiece(resource.name()).substr(0, container.size()) !=
container) {
return errors::InvalidArgument("Wrong input container. ",
@@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5aa0..2fbe1fe7cb 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index d3c4f62071..83b83fcdb9 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -21,6 +21,7 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+ core::ScopedUnref scoped_unref(var);
return var->mu();
} else {
ctx->CtxFailureWithWarning(
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index 62e814ff77..8d839ba85a 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel {
auto output = output_tensor->matrix<Tidx>();
- Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
- Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
- Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
- Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+ Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<Eigen::Index, 2> indices_reshape{
+ {1, indices_tensor.NumElements()}};
+ Eigen::array<Eigen::Index, 2> indices_bcast(
+ {dims_tensor.NumElements(), 1});
output = indices_tensor.vec<Tidx>()
.reshape(indices_reshape)
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9dd4c..2255597651 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h
index 49f74ff47f..eb0ff555a5 100644
--- a/tensorflow/core/lib/core/status.h
+++ b/tensorflow/core/lib/core/status.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index e7b17c9b36..6edff139ae 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -26,13 +26,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
#define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
-#include <assert.h>
-#include <stddef.h>
-#include <string.h>
-#include <iosfwd>
-#include <string>
#include "absl/strings/string_view.h"
-#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae47b..9ccd911b0e 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) {
impl_->Schedule(std::move(fn));
}
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+ const int64 block_size, const int64 total) {
+ if (block_size <= 0 || total <= 1 || total <= block_size ||
+ NumThreads() == 1) {
+ return 1;
+ }
+ return (total + block_size - 1) / block_size;
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::TransformRangeConcurrently(
+ const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn) {
+ const int num_shards_used =
+ NumShardsUsedByTransformRangeConcurrently(block_size, total);
+ if (num_shards_used == 1) {
+ fn(0, total);
+ return;
+ }
+
+ // Adapted from Eigen's parallelFor implementation.
+ BlockingCounter counter(num_shards_used);
+ std::function<void(int64, int64)> handle_range =
+ [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+ while (last - first > block_size) {
+ // Find something near the midpoint which is a multiple of block size.
+ const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+ block_size * block_size;
+ Schedule([=, &handle_range]() { handle_range(mid, last); });
+ last = mid;
+ }
+ // Single block or less, execute directly.
+ fn(first, last);
+ counter.DecrementCount(); // The shard is done.
+ };
+ if (num_shards_used <= NumThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // main thread.
+ handle_range(0, total);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // numThreads() threads.
+ Schedule([=, &handle_range]() { handle_range(0, total); });
+ }
+ counter.Wait();
+}
+
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index 74df7c84a4..e14ad7ac64 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -59,6 +59,20 @@ class ThreadPool {
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn);
+ // Requires 0 < block_size <= total.
+ // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the
+ // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total)
+ // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...).
+ // Note that when there aren't enough threads in the pool to achieve full
+ // parallelism, function calls will be automatically queued.
+ void TransformRangeConcurrently(const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn);
+
+ // Returns the number of threads spawned by calling TransformRangeConcurrently
+ // with these parameters.
+ int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+ const int64 total);
+
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
// indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3ebb83..db996b783f 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) {
}
}
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ threads->TransformRangeConcurrently(
+ block_size, total,
+ [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+ VLOG(1) << "Shard [" << start << "," << end << ")";
+ EXPECT_GE(start, 0);
+ EXPECT_LE(end, total);
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < end; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ LOG(INFO) << block_size << " " << total;
+ const int64 num_workers = (total + block_size - 1) / block_size;
+ EXPECT_EQ(num_done_work, total);
+ if (num_workers < threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, we should expect that Shard() implementation do
+ // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+ // tends to over-shard.
+ EXPECT_LE(num_shards, 1 + num_workers);
+ }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(SparseUtilsTest, TransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ const int64 total = block_size + diff;
+ RunSharding(block_size, total, &threads);
+ }
+ }
+}
+
+TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 3 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 4 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 5 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 6 /* total */));
+ EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 7 /* total */));
+ EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+ 1 /* block_size */, 7 /* total */));
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 0 /* block_size */, 7 /* total */));
+}
+
TEST(ThreadPool, ParallelFor) {
Context outer_context(ContextKind::kThread);
// Make ParallelFor use as many threads as possible.
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
index e2927689d2..117b6a0bb8 100644
--- a/tensorflow/core/lib/io/block_builder.h
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace table {
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index e3649fd0c9..38fb0c5d86 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_LIB_IO_PATH_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace io {
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index f93ebea771..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,6 +108,59 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, we always increment offset positively, so this
+ // loop should be guaranteed to either return after reaching EOF
+ // or encountering an error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index 11af1366b0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -66,6 +66,18 @@ class RecordReader {
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -79,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetStats(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'metadata' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -86,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index da514bd21c..946d7188d3 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -58,7 +58,7 @@ class StringDest : public WritableFile {
Status Close() override { return Status::OK(); }
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& slice) override {
+ Status Append(StringPiece slice) override {
contents_->append(slice.data(), slice.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 877ac40f1c..9cebbf40c6 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -98,7 +98,7 @@ class StringSink : public WritableFile {
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
contents_.append(data.data(), data.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 84b47c171f..cba139e6ad 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() {
return Status::OK();
}
-Status ZlibOutputBuffer::Append(const StringPiece& data) {
+Status ZlibOutputBuffer::Append(StringPiece data) {
// If there is sufficient free space in z_stream_input_ to fit data we
// add it there and return.
// If there isn't enough space we deflate the existing contents of
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 3d86d89a99..ccad2fda44 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile {
// to file when the buffer is full.
//
// To immediately write contents to file call `Flush()`.
- Status Append(const StringPiece& data) override;
+ Status Append(StringPiece data) override;
// Deflates any cached input and writes all output to file.
Status Flush() override;
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
index 50ed8bdb3b..f7a359eb5b 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -152,7 +152,9 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
cinfo.scale_denom = ratio;
cinfo.dct_method = flags.dct_method;
- jpeg_start_decompress(&cinfo);
+ // Determine the output image size before attempting decompress to prevent
+ // OOM'ing doing the decompress
+ jpeg_calc_output_dimensions(&cinfo);
int64 total_size = static_cast<int64>(cinfo.output_height) *
static_cast<int64>(cinfo.output_width);
@@ -170,6 +172,8 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
return nullptr;
}
+ jpeg_start_decompress(&cinfo);
+
JDIMENSION target_output_width = cinfo.output_width;
JDIMENSION target_output_height = cinfo.output_height;
JDIMENSION skipped_scanlines = 0;
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index c204d52cfe..9e4e1989dd 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 756e5c2af8..bc4365e439 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
index bb5d20fb68..c876c5156a 100644
--- a/tensorflow/core/lib/png/png_io.h
+++ b/tensorflow/core/lib/png/png_io.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/png.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace png {
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 36d939e061..c536b5688e 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
"Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
}
TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+ if (*channel_count < 1) {
+ return errors::InvalidArgument(
+ "Bad number of channels for WAV: Expected at least 1, but got ",
+ *channel_count);
+ }
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
uint32 bytes_per_second;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 7dbb18aa5d..442686c92a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2595,6 +2595,116 @@ REGISTER_OP("ExtractImagePatches")
// --------------------------------------------------------------------------
+// To enable rates, uncomment all lines commented below and use ksize_*_eff
+// as the second parameter of all GetWindowedOutputSizeVerbose calls instead
+// of ksize_*.
+REGISTER_OP("ExtractVolumePatches")
+ .Input("input: T")
+ .Output("patches: T")
+ .Attr("ksizes: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ /* .Attr("rates: list(int) >= 5") */
+ .Attr("T: realnumbertype")
+ .Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> ksizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
+ if (ksizes.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the ksizes attribute to contain 5 "
+ "values, but got: ",
+ ksizes.size());
+ }
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the stride attribute to contain 5 "
+ "values, but got: ",
+ strides.size());
+ }
+
+ /*
+ // TODO(hsgkim): Enable rates.
+ // See extract_volume_patches_op.cc for why rates are disabled now.
+
+ std::vector<int32> rates;
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
+ if (rates.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the rates attribute to contain 5 "
+ "values, but got: ",
+ rates.size());
+ }
+ */
+
+ int32 ksize_planes = ksizes[1];
+ int32 ksize_rows = ksizes[2];
+ int32 ksize_cols = ksizes[3];
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ /*
+ int32 rate_planes = rates[1];
+ int32 rate_rows = rates[2];
+ int32 rate_cols = rates[3];
+
+ int32 ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim;
+ TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+ ksize_planes * ksize_rows * ksize_cols,
+ &output_depth_dim));
+
+ if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+ !c->ValueKnown(in_cols_dim)) {
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim, output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, ksize_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, ksize_rows, stride_rows, padding, &output_rows,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, ksize_cols, stride_cols, padding, &output_cols,
+ &padding_before, &padding_after));
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ });
+
+// --------------------------------------------------------------------------
+
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
@@ -2916,6 +3026,34 @@ Status ScatterNdShape(InferenceContext* c) {
} // namespace
+REGISTER_OP("UpperBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
+REGISTER_OP("LowerBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
REGISTER_OP("ScatterNd")
.Input("indices: Tindices")
.Input("updates: T")
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3e85..b8cf538554 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -176,6 +180,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary")
return Status::OK();
});
+// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
+// bucketized_features to features.
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
@@ -354,4 +360,125 @@ REGISTER_OP("BoostedTreesCenterBias")
return Status::OK();
});
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+ .Attr("max_elements: int = 1099511627776") // 1 << 40
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("epsilon: float")
+ .Input("num_streams: int64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("example_weights: float")
+ .Input("epsilon: float")
+ .Output("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle example_weights_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features), 1, &example_weights_shape));
+ for (int i = 0; i < num_features; ++i) {
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(example_weights_shape, 0),
+ &unused_dim));
+ // the columns are value, weight, min_rank, max_rank.
+ c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+ }
+ // epsilon must be a scalar.
+ ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ // resource handle must be a scalar.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ // each summary must be rank 2.
+ for (int i = 1; i < num_features + 1; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+ .Attr("generate_quantiles: bool = False")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("num_buckets: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ // All the inputs are scalars.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("bucket_boundaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle unused_input;
+ // resource handle must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->Vector(c->UnknownDim()));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesBucketize")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("bucket_boundaries: num_features * float")
+ .Output("buckets: num_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ for (int i = 0; i < num_features; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(c->input(0), 0), &unused_dim));
+ }
+ // Bucketized result should have same dimension as input.
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 34e6b5560b..86d4c6b421 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11360,6 +11360,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -11469,6 +11492,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -11562,6 +11608,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -11631,6 +11703,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -21753,6 +21902,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -27192,6 +27394,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -29227,6 +29441,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -34950,6 +35196,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -35057,6 +35326,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -38696,6 +39093,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -59664,6 +60085,29 @@ op {
}
}
op {
+ name: "Softplus"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftplusGrad"
input_arg {
name: "gradients"
@@ -59800,6 +60244,33 @@ op {
}
}
op {
+ name: "SoftplusGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Softsign"
input_arg {
name: "features"
@@ -59920,6 +60391,29 @@ op {
}
}
op {
+ name: "Softsign"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftsignGrad"
input_arg {
name: "gradients"
@@ -60056,6 +60550,33 @@ op {
}
}
op {
+ name: "SoftsignGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SpaceToBatch"
input_arg {
name: "input"
@@ -70004,6 +70525,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -70040,6 +70598,30 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -75083,6 +75665,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -75418,9 +76032,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "shift"
type: DT_INT64
}
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a897a..f84142c992 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd867561b..13c3b933f4 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 9d2b3af51d..1ada623cf5 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -396,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset")
REGISTER_OP("WindowDataset")
.Input("input_dataset: variant")
- .Input("window_size: int64")
+ .Input("size: int64")
+ .Input("shift: int64")
+ .Input("stride: int64")
+ .Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
- // batch_size should be a scalar.
+ // size, shift, stride, and drop_remainder should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
return shape_inference::ScalarShape(c);
});
@@ -873,6 +879,13 @@ REGISTER_OP("IteratorGetNextAsOptional")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ModelDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
.Output("output: output_types")
@@ -919,4 +932,41 @@ REGISTER_OP("MapDefun")
return Status::OK();
});
+REGISTER_OP("MultiDeviceIterator")
+ .Output("handle: resource")
+ .Attr("devices: list(string) >= 1")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorInit")
+ .Input("dataset: variant")
+ .Input("multi_device_iterator: resource")
+ .Input("max_buffer_size: int64")
+ .Output("incarnation_id: int64")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
+ .Input("multi_device_iterator: resource")
+ .Input("shard_num: int32")
+ .Input("incarnation_id: int64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("MultiDeviceIteratorToStringHandle")
+ .Input("multi_device_iterator: resource")
+ .Output("string_handle: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorFromStringHandle")
+ .Input("string_handle: string")
+ .Output("multi_device_iterator: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 639d211767..2034d3601b 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -20,6 +20,8 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -44,6 +46,23 @@ REGISTER_OP("Print")
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+REGISTER_OP("PrintV2")
+ .Input("input: string")
+ .SetIsStateful()
+ .Attr(
+ "output_stream: {'stdout', 'stderr', 'log(info)', "
+ "'log(warning)', 'log(error)'} = 'stderr'")
+ .SetShapeFn([](InferenceContext* c) {
+ // Make sure that the input is a scalar.
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("input must be a scalar, but has rank: ",
+ c->Rank(c->input(0)));
+ }
+ return Status::OK();
+ });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 6c318e358a..6191a88e5b 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1024,32 +1024,30 @@ REGISTER_OP("SeluGrad")
.Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftplusGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
-// TODO(b/111515541): change T to {half, bfloat16, float, double}
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftsignGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
// --------------------------------------------------------------------------
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index a2fc76c8b6..3ae4f1a59e 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4272,6 +4272,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -4381,6 +4404,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -4474,6 +4520,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -4543,6 +4615,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -10038,6 +10187,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -13162,6 +13364,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -14443,6 +14657,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -16628,6 +16874,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -16664,6 +16933,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -19405,6 +19802,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -28361,18 +28782,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28396,18 +28809,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28427,18 +28832,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -28462,18 +28859,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -32619,6 +33008,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -32653,6 +33079,19 @@ op {
name: "output"
type: DT_INT32
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "StringSplit"
@@ -35838,6 +36277,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -36083,9 +36554,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "shift"
type: DT_INT64
}
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index 79ca96d249..eff453241d 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -343,10 +343,11 @@ REGISTER_OP("DecodeCSV")
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
- if (c->Value(c->Dim(v, 0)) > 1) {
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector");
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
}
}
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index c65e66d1a8..ba594e400c 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) {
INFER_OK(op, "[1,2,?,4];?;?", "in0;in0");
INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0");
+ // Scalar defaults are ok
+ INFER_OK(op, "?;?;[]", "in0;in0");
+
// Check errors in the record_defaults inputs.
- INFER_ERROR("must be rank 1", op, "?;?;[]");
- INFER_ERROR("must be rank 1", op, "?;[];?");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?");
INFER_ERROR("Shape of a default must be", op, "?;?;[2]");
INFER_ERROR("Shape of a default must be", op, "?;[2];?");
}
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index ef8b15dc8a..da1d2a6432 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -102,6 +103,32 @@ REGISTER_OP("AsString")
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+ .Input("inputs: T")
+ .Output("output: string")
+ .Attr("T: list(type) >= 0")
+ .Attr("template: string = '%s'")
+ .Attr("placeholder: string = '%s'")
+ .Attr("summarize: int = 3")
+ .SetShapeFn([](InferenceContext* c) {
+ string template_;
+ string placeholder;
+ TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+ TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+ std::vector<std::string> split_template;
+ split_template = absl::StrSplit(template_, placeholder);
+ int64 num_placeholders = split_template.size() - 1;
+ if (c->num_inputs() != num_placeholders) {
+ return errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", c->num_inputs()));
+ }
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
@@ -176,6 +203,7 @@ REGISTER_OP("StringStrip")
REGISTER_OP("StringLength")
.Input("input: string")
.Output("output: int32")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("EncodeBase64")
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index e597a490d6..d7a13a3528 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name,
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name) {
+string MaybeAbiDemangle(const char* name) {
#if defined(_MSC_VER)
std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
std::free,
static_cast<unsigned short>(0))};
- return std::string(demangled.get() != nullptr ? demangled.get() : name);
+ return string(demangled.get() != nullptr ? demangled.get() : name);
#else
int status = 0;
std::unique_ptr<char, void (*)(void*)> res{
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 591e83b0c4..d1498a6a64 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -17,11 +17,12 @@ limitations under the License.
#define TENSORFLOW_CORE_PLATFORM_ABI_H_
#include <string>
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name);
+string MaybeAbiDemangle(const char* name);
} // namespace port
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 8f959c018e..83ea8539ed 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@ limitations under the License.
#ifdef _WIN32
#include <io.h> // for _mktemp
#endif
+#include "absl/base/macros.h"
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@ constexpr int kGetChildrenDefaultPageSize = 1000;
// The HTTP response code "308 Resume Incomplete".
constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
// The environment variable that disables the GCS block cache for reads.
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
@@ -371,7 +372,7 @@ class GcsWritableFile : public WritableFile {
~GcsWritableFile() override { Close().IgnoreError(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
TF_RETURN_IF_ERROR(CheckWritable());
sync_needed_ = true;
outfile_ << data;
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 92aa72be89..941ab7ad65 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -177,7 +177,7 @@ class RetryingWritableFile : public WritableFile {
Close().IgnoreError();
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
[this, &data]() { return base_file_->Append(data); },
initial_delay_microseconds_);
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index ec2c470db7..5910fef1d2 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile {
class MockWritableFile : public WritableFile {
public:
explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return calls_.ConsumeNextCall("Append");
}
Status Close() override { return calls_.ConsumeNextCall("Close"); }
diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h
new file mode 100644
index 0000000000..7c5c6655be
--- /dev/null
+++ b/tensorflow/core/platform/cord.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_CORD_H_
+
+// Include appropriate platform-dependent implementations
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/cord.h"
+#else
+#include "tensorflow/core/platform/default/cord.h"
+#endif
+
+#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c23fd..37475feebe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_sycl_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_additional_plugin_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): [
- str(Label("//tensorflow/compiler/jit"))
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): [
+ str(Label("//tensorflow/compiler/jit")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_xla_deps_py():
- return []
+ return []
def tf_additional_grpc_deps_py():
- return []
+ return []
def tf_additional_license_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+ "//conditions:default": [],
+ })
def tf_additional_verbs_deps():
- return select({
- str(Label("//tensorflow:with_verbs_support")): [
- str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
- str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_verbs_support")): [
+ str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+ str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_deps():
- return select({
- str(Label("//tensorflow:with_mpi_support")): [
- str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_mpi_support")): [
+ str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_deps():
- return select({
- str(Label("//tensorflow:with_gdr_support")): [
- str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_gdr_support")): [
+ str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
-def if_static(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:framework_shared_object")): otherwise,
- "//conditions:default": extra_deps,
- })
+def if_static(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:framework_shared_object")): otherwise,
+ "//conditions:default": extra_deps,
+ })
-def if_dynamic_kernels(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
- "//conditions:default": otherwise,
- })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
new file mode 100644
index 0000000000..5823374d1a
--- /dev/null
+++ b/tensorflow/core/platform/default/cord.h
@@ -0,0 +1,21 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+
+// TODO(ebrevdo): Fill this in.
+
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index 0389149469..83c65dbfa9 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,7 +321,12 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
- bool IsEnabled(bool is_expensive) const override {
+ bool IsEnabledForAnnotations() const override {
+ // We are always enabled for 'Annotations'.
+ return true;
+ }
+
+ bool IsEnabledForActivities(bool is_expensive) const override {
// We don't do anything with 'Activities' so we are never 'enabled'.
return false;
}
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 305a9a682f..2e32abdffb 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) {
// Write something to the temporary file.
std::unique_ptr<WritableFile> file_to_write;
TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
+#if defined(PLATFORM_GOOGLE)
+ TF_CHECK_OK(file_to_write->Append("Nu"));
+ TF_CHECK_OK(file_to_write->Append(absl::Cord("ll")));
+#else
+ // TODO(ebrevdo): Remove this version.
TF_CHECK_OK(file_to_write->Append("Null"));
+#endif
TF_CHECK_OK(file_to_write->Close());
TF_CHECK_OK(env->FileExists(filename));
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 077b1d79cf..156af6cdea 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/file_statistics.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/platform.h"
@@ -252,7 +253,15 @@ class WritableFile {
virtual ~WritableFile();
/// \brief Append 'data' to the file.
- virtual Status Append(const StringPiece& data) = 0;
+ virtual Status Append(StringPiece data) = 0;
+
+ // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
+ // \brief Append 'data' to the file.
+ virtual Status Append(const absl::Cord& cord) {
+ return errors::Unimplemented("Append(absl::Cord) is not implemented");
+ }
+#endif
/// \brief Close the file.
///
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 8cdb08f51b..eb35531e9f 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (hdfs_->hdfsWrite(fs_, file_, data.data(),
static_cast<tSize>(data.size())) == -1) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 47bfa020ce..c7afab9583 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
size_t r = fwrite(data.data(), 1, data.size(), file_);
if (r != data.size()) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index ce0f6cd741..e0b8e37745 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile {
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
std::ios_base::out)) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (!outfile_) {
return errors::FailedPrecondition(
"The internal temporary file is not writable.");
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index 9974bbbb4e..aefbe64425 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -155,9 +155,12 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+ // Returns true if this annotation tracing is enabled for any op.
+ virtual bool IsEnabledForAnnotations() const = 0;
+
// Returns true if this activity handle tracking is enabled for an op of the
// given expensiveness.
- virtual bool IsEnabled(bool is_expensive) const = 0;
+ virtual bool IsEnabledForActivities(bool is_expensive) const = 0;
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 9079a5ccaa..6cf79634d7 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
DWORD bytes_written = 0;
DWORD data_size = static_cast<DWORD>(data.size());
BOOL write_result =
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 625d5649e6..85cd02350a 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -68,7 +68,7 @@ message GPUOptions {
// after the process starts. Users are required to use vendor
// specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the
// physical to visible device mapping prior to invoking TensorFlow.
- // 2. In the code, the ids in this list are also called "CUDA GPU id"s,
+ // 2. In the code, the ids in this list are also called "platform GPU id"s,
// and the 'virtual' ids of GPU devices (i.e. the ids in the device
// name "/device:GPU:<id>") are also called "TF GPU id"s. Please
// refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000000..7644314fc9
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session. We record the device listing
+// here to capture the state of the cluster.
+message NewReplaySession {
+ ListDevicesResponse devices = 1;
+ string session_handle = 2;
+}
+
+message ReplayOp {
+ double start_time_us = 31;
+ double end_time_us = 32;
+
+ oneof op {
+ CreateSessionRequest create_session = 1;
+ ExtendSessionRequest extend_session = 2;
+ PartialRunSetupRequest partial_run_setup = 3;
+ RunStepRequest run_step = 4;
+ CloseSessionRequest close_session = 5;
+ ListDevicesRequest list_devices = 6;
+ ResetRequest reset_request = 7;
+ MakeCallableRequest make_callable = 8;
+ RunCallableRequest run_callable = 9;
+ ReleaseCallableRequest release_callable = 10;
+ NewReplaySession new_replay_session = 11;
+ }
+
+ oneof response {
+ CreateSessionResponse create_session_response = 21;
+ ExtendSessionResponse extend_session_response = 22;
+ PartialRunSetupResponse partial_run_setup_response = 23;
+ RunStepResponse run_step_response = 24;
+ CloseSessionResponse close_session_response = 25;
+ ListDevicesResponse list_devices_response = 26;
+ ResetResponse reset_request_response = 27;
+ MakeCallableResponse make_callable_response = 28;
+ RunCallableResponse run_callable_response = 29;
+ ReleaseCallableResponse release_callable_response = 30;
+ }
+}
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 07f984ceea..bb8f88336d 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,6 +75,8 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
+ // Force small ops onto the CPU (default is ON).
+ Toggle pin_to_host_optimization = 18;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 4129c93af5..b043a69431 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 10
+#define TF_MINOR_VERSION 11
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb58d4..f6f0408ccc 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
}
namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
- IntType* orig = first;
- IntType* it = nullptr;
- IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
while (count > 0) {
it = first;
step = count / 2;
@@ -112,6 +112,27 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
+ while (count > 0) {
+ it = first;
+ step = count / 2;
+ it += step;
+ if (*it < val) {
+ first = ++it;
+ count -= step + 1;
+ } else {
+ count = step;
+ }
+ }
+
+ return first - orig;
+}
+
} // namespace cuda_helper
} // namespace tensorflow
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 680211edff..cf7ffd8149 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -34,9 +34,8 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
-// Using pragma message since #warning doesn't work with all compilers
-#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
-#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#error \
+ "Compiling for INTEL MKL ML only is no longer supported.Please use MKL DNN (the default option for --config=mkl)"
#endif
#ifdef INTEL_MKL_ML_ONLY
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933051..546b0a833c 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@ namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
- int64 N = iter_->ix_.dim_size(0);
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
+ const int64 N = ix_t.dimension(0);
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
@@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
std::vector<int64> Group::group() const {
std::vector<int64> g;
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
@@ -62,8 +62,8 @@ std::vector<int64> Group::group() const {
}
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
- return TTypes<int64>::UnalignedConstMatrix(
- &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+ return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+ next_loc_ - loc_, iter_->dims_);
}
} // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6116..14610c61d9 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@ class GroupIterable {
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix),
+ ix_matrix_(ix_.matrix<int64>()),
vals_(vals),
dims_(dims),
group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@ class GroupIterable {
private:
friend class Group;
- Tensor ix_;
+ const Tensor ix_;
+ const TTypes<int64>::ConstMatrix ix_matrix_;
Tensor vals_;
const int dims_;
const gtl::InlinedVector<int64, 8> group_dims_;
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65f60..b9ca8ab395 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@ class SparseTensor {
SparseTensor() : dims_(0) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
@@ -237,9 +238,10 @@ class SparseTensor {
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- // DEPRECATED: use the form of Split() that takes an output pointer and
- // returns a status instead.
template <typename T>
+ ABSL_DEPRECATED(
+ "Use the form of Split() that takes an output pointer and returns a "
+ "status instead.")
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
const int num_split,
diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h
index 6539d565e2..7b101971a8 100644
--- a/tensorflow/core/util/tensor_bundle/naming.h
+++ b/tensorflow/core/util/tensor_bundle/naming.h
@@ -35,6 +35,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index f4bd2950e9..74f0713a61 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -50,6 +50,8 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
max_parallelism);
}
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size.
void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index b12c31c1ae..9db85a54c6 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -23,6 +23,9 @@ limitations under the License.
namespace tensorflow {
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size. Use this function only if you want to
+// manually cap parallelism.
// Shards the "total" unit of work assuming each unit of work having
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD
index 3630b41fc8..3630b41fc8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/examples/autograph/integration_tests/BUILD
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
index 04a968be10..9c10dad9aa 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/errors_test.py
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -20,21 +20,18 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph as ag
-from tensorflow.python.util import tf_inspect
+from tensorflow.python import autograph as ag
class ErrorsTest(tf.test.TestCase):
def test_graph_construction_error_rewriting_call_tree(self):
- def innermost(x):
- if x > 0:
- return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
- return tf.zeros((2, 3))
+ def test_fn():
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
def inner_caller():
- return innermost(1.0)
+ return test_fn()
def caller():
return inner_caller()
@@ -45,23 +42,21 @@ class ErrorsTest(tf.test.TestCase):
expected = error.exception
custom_traceback = expected.custom_traceback
found_correct_filename = False
- num_innermost_names = 0
+ num_test_fn_names = 0
num_inner_caller_names = 0
num_caller_names = 0
- ag_output_filename = tf_inspect.getsourcefile(graph)
for frame in custom_traceback:
filename, _, fn_name, _ = frame
- self.assertFalse('control_flow_ops.py' in filename)
- self.assertFalse(ag_output_filename in filename)
+ self.assertFalse('/tmp/' in filename)
found_correct_filename |= __file__ in filename
self.assertNotEqual('tf__test_fn', fn_name)
- num_innermost_names += int('innermost' == fn_name)
+ num_test_fn_names += int('test_fn' == fn_name)
self.assertNotEqual('tf__inner_caller', fn_name)
num_inner_caller_names += int('inner_caller' == fn_name)
self.assertNotEqual('tf__caller', fn_name)
num_caller_names += int('caller' == fn_name)
self.assertTrue(found_correct_filename)
- self.assertEqual(num_innermost_names, 1)
+ self.assertEqual(num_test_fn_names, 1)
self.assertEqual(num_inner_caller_names, 1)
self.assertEqual(num_caller_names, 1)
@@ -106,19 +101,14 @@ class ErrorsTest(tf.test.TestCase):
found_correct_filename = False
num_test_fn_frames = 0
num_g_frames = 0
- ag_output_filename = tf_inspect.getsourcefile(compiled_fn)
for frame in custom_traceback:
filename, _, fn_name, source_code = frame
- self.assertFalse(ag_output_filename in filename)
- self.assertFalse('control_flow_ops.py' in filename)
+ self.assertFalse('/tmp/' in filename)
+ self.assertFalse('control_flow.py' in filename)
self.assertFalse('ag__.' in fn_name)
- self.assertFalse('tf__g' in fn_name)
- self.assertFalse('tf__test_fn' in fn_name)
found_correct_filename |= __file__ in filename
num_test_fn_frames += int('test_fn' == fn_name and
'return g(x, 10)' in source_code)
- # This makes sure that the code is correctly rewritten from "x_1 //= 0" to
- # "x //= 0".
num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
self.assertTrue(found_correct_filename)
self.assertEqual(num_test_fn_frames, 1)
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
index 7e7ef5a3e2..dca7c07b47 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph
+from tensorflow.python import autograph
class MinimalKeras(tf.keras.Model):
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
index 904246afb7..917f5ff9d8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/list_literals_test.py
+++ b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib import autograph as ag
+from tensorflow.python import autograph as ag
def list_used_as_tuple():
diff --git a/tensorflow/examples/learn/text_classification_character_cnn.py b/tensorflow/examples/learn/text_classification_character_cnn.py
index afda170e2a..b8506fa8a4 100644
--- a/tensorflow/examples/learn/text_classification_character_cnn.py
+++ b/tensorflow/examples/learn/text_classification_character_cnn.py
@@ -74,7 +74,7 @@ def char_cnn_model(features, labels, mode):
kernel_size=FILTER_SHAPE2,
padding='VALID')
# Max across each filter to get useful features for classification.
- pool2 = tf.squeeze(tf.reduce_max(conv2, 1), squeeze_dims=[1])
+ pool2 = tf.squeeze(tf.reduce_max(conv2, 1), axis=[1])
# Apply regular WX + B and classification.
logits = tf.layers.dense(pool2, MAX_LABEL, activation=None)
diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py
index c8de6c2152..0c7ca9bc01 100644
--- a/tensorflow/examples/speech_commands/freeze_test.py
+++ b/tensorflow/examples/speech_commands/freeze_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.platform import test
class FreezeTest(test.TestCase):
def testCreateInferenceGraphWithMfcc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
@@ -44,7 +44,7 @@ class FreezeTest(test.TestCase):
self.assertEqual(1, ops.count('Mfcc'))
def testCreateInferenceGraphWithoutMfcc(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
@@ -63,7 +63,7 @@ class FreezeTest(test.TestCase):
self.assertEqual(0, ops.count('Mfcc'))
def testFeatureBinCount(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
diff --git a/tensorflow/examples/speech_commands/input_data_test.py b/tensorflow/examples/speech_commands/input_data_test.py
index 2e551be9a2..aa4e807779 100644
--- a/tensorflow/examples/speech_commands/input_data_test.py
+++ b/tensorflow/examples/speech_commands/input_data_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class InputDataTest(test.TestCase):
def _getWavData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sample_data = tf.zeros([32000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
@@ -75,7 +75,7 @@ class InputDataTest(test.TestCase):
self._saveTestWavFile(file_path, wav_data)
model_settings = models.prepare_model_settings(
4, 16000, 1000, window_length_ms, 20, 40, preprocess)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
audio_processor = input_data.AudioProcessor(
"", wav_dir, 10, 10, ["a", "b"], 10, 10, model_settings, tmp_dir)
result_data, result_labels = audio_processor.get_data(
diff --git a/tensorflow/examples/speech_commands/label_wav_test.py b/tensorflow/examples/speech_commands/label_wav_test.py
index 80ca774706..f0af2a4798 100644
--- a/tensorflow/examples/speech_commands/label_wav_test.py
+++ b/tensorflow/examples/speech_commands/label_wav_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class LabelWavTest(test.TestCase):
def _getWavData(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sample_data = tf.zeros([1000, 2])
wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
wav_data = sess.run(wav_encoder)
diff --git a/tensorflow/examples/speech_commands/models_test.py b/tensorflow/examples/speech_commands/models_test.py
index 0c373967ed..04478c0962 100644
--- a/tensorflow/examples/speech_commands/models_test.py
+++ b/tensorflow/examples/speech_commands/models_test.py
@@ -49,7 +49,7 @@ class ModelsTest(test.TestCase):
def testCreateModelConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(fingerprint_input,
model_settings, "conv", True)
@@ -60,7 +60,7 @@ class ModelsTest(test.TestCase):
def testCreateModelConvInference(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits = models.create_model(fingerprint_input, model_settings, "conv",
False)
@@ -69,7 +69,7 @@ class ModelsTest(test.TestCase):
def testCreateModelLowLatencyConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "low_latency_conv", True)
@@ -80,7 +80,7 @@ class ModelsTest(test.TestCase):
def testCreateModelFullyConnectedTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "single_fc", True)
@@ -91,7 +91,7 @@ class ModelsTest(test.TestCase):
def testCreateModelBadArchitecture(self):
model_settings = self._modelSettings()
- with self.test_session():
+ with self.cached_session():
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
with self.assertRaises(Exception) as e:
models.create_model(fingerprint_input, model_settings,
@@ -100,7 +100,7 @@ class ModelsTest(test.TestCase):
def testCreateModelTinyConvTraining(self):
model_settings = self._modelSettings()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
fingerprint_input = tf.zeros([1, model_settings["fingerprint_size"]])
logits, dropout_prob = models.create_model(
fingerprint_input, model_settings, "tiny_conv", True)
diff --git a/tensorflow/examples/tutorials/mnist/BUILD b/tensorflow/examples/tutorials/mnist/BUILD
index d4070fdd1e..99da44d6d5 100644
--- a/tensorflow/examples/tutorials/mnist/BUILD
+++ b/tensorflow/examples/tutorials/mnist/BUILD
@@ -84,6 +84,18 @@ py_binary(
)
py_binary(
+ name = "mnist_softmax_xla",
+ srcs = [
+ "mnist_softmax_xla.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":input_data",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_binary(
name = "mnist_deep",
srcs = [
"mnist_deep.py",
diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md
index 288a32530a..3989f9b25a 100644
--- a/tensorflow/go/README.md
+++ b/tensorflow/go/README.md
@@ -10,7 +10,7 @@ Construct and execute TensorFlow graphs in Go.
## Quickstart
-Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/install_go)
+Refer to [Installing TensorFlow for Go](https://www.tensorflow.org/install/lang_go)
## Building the TensorFlow C library from source
@@ -23,9 +23,7 @@ from source.
- [bazel](https://www.bazel.build/versions/master/docs/install.html)
- Environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [OS
- X](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux of macOS](https://www.tensorflow.org/install/source)).
If you don't need GPU support, then try the following:
```sh
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e755c37039..9dd487e73b 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -332,7 +332,7 @@ func FakeQuantWithMinMaxArgs(scope *Scope, inputs tf.Output, optional ...FakeQua
// Creates a new tensor by applying sparse `updates` to individual values or
// slices within a tensor (initially zero for numeric, empty for string) of
// the given `shape` according to indices. This operator is the inverse of the
-// @{tf.gather_nd} operator which extracts values or slices from a given tensor.
+// `tf.gather_nd` operator which extracts values or slices from a given tensor.
//
// If `indices` contains duplicates, then their updates are accumulated (summed).
//
@@ -1473,7 +1473,7 @@ type StridedSliceAttr func(optionalAttr)
//
// value: a bitmask where a bit i being 1 means to ignore the begin
// value and instead use the largest interval possible. At runtime
-// begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+// begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
// `[-1, n-1]` if `stride[i] < 0`
// If not specified, defaults to 0
func StridedSliceBeginMask(value int64) StridedSliceAttr {
@@ -1856,6 +1856,32 @@ func ReverseSequence(scope *Scope, input tf.Output, seq_lengths tf.Output, seq_d
return op.Output(0)
}
+// Ensures that the tensor's shape matches the expected shape.
+//
+// Raises an error if the input tensor's shape does not match the specified shape.
+// Returns the input tensor otherwise.
+//
+// Arguments:
+// input: A tensor, whose shape is to be validated.
+// shape: The expected (possibly partially specified) shape of the input tensor.
+//
+// Returns A tensor with the same shape and contents as the input tensor or value.
+func EnsureShape(scope *Scope, input tf.Output, shape tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "EnsureShape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// UniqueWithCountsV2Attr is an optional argument to UniqueWithCountsV2.
type UniqueWithCountsV2Attr func(optionalAttr)
@@ -2259,7 +2285,7 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou
//
// output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]]
//
-// Whereas in @{tf.gather} `indices` defines slices into the first
+// Whereas in `tf.gather` `indices` defines slices into the first
// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
// first `N` dimensions of `params`, where `N = indices.shape[-1]`.
//
@@ -2356,6 +2382,8 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou
// output = [['b0', 'b1'], ['d0', 'c1']]
// ```
//
+// See also `tf.gather` and `tf.batch_gather`.
+//
// Arguments:
// params: The tensor from which to gather values.
// indices: Index tensor.
@@ -2433,6 +2461,64 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe
return op.Output(0)
}
+// LowerBoundAttr is an optional argument to LowerBound.
+type LowerBoundAttr func(optionalAttr)
+
+// LowerBoundOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func LowerBoundOutType(value tf.DataType) LowerBoundAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Applies lower_bound(sorted_search_values, values) along each row.
+//
+// Each set of rows with the same index in (sorted_inputs, values) is treated
+// independently. The resulting row is the equivalent of calling
+// `np.searchsorted(sorted_inputs, values, side='left')`.
+//
+// The result is not a global index to the entire
+// `Tensor`, but rather just the index in the last dimension.
+//
+// A 2-D example:
+// sorted_sequence = [[0, 3, 9, 9, 10],
+// [1, 2, 3, 4, 5]]
+// values = [[2, 4, 9],
+// [0, 2, 6]]
+//
+// result = LowerBound(sorted_sequence, values)
+//
+// result == [[1, 2, 2],
+// [0, 1, 5]]
+//
+// Arguments:
+// sorted_inputs: 2-D Tensor where each row is ordered.
+// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
+// the values that will be searched for in `sorted_search_values`.
+//
+// Returns A `Tensor` with the same shape as `values`. It contains the first scalar index
+// into the last dimension where values can be inserted without changing the
+// ordered property.
+func LowerBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...LowerBoundAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LowerBound",
+ Input: []tf.Input{
+ sorted_inputs, values,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a tensor filled with a scalar value.
//
// This operation creates a tensor of shape `dims` and fills it with `value`.
@@ -2445,6 +2531,16 @@ func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...Gathe
// [9, 9, 9]]
// ```
//
+// `tf.fill` differs from `tf.constant` in a few ways:
+//
+// * `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+// Tensor values.
+// * `tf.fill` creates an Op in the computation graph that constructs the actual
+// Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+// the entire Tensor into the graph with a `Const` node.
+// * Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+// based on other runtime Tensors, unlike `tf.constant`.
+//
// Arguments:
// dims: 1-D. Represents the shape of the output tensor.
// value: 0-D (scalar). Value to fill the returned tensor.
@@ -2858,6 +2954,25 @@ func GuaranteeConst(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
+// Returns a constant tensor on the host. Only for writing C++ tests.
+//
+// Arguments:
+// value: Attr `value` is the tensor to return.
+//
+func HostConst(scope *Scope, value tf.Tensor, dtype tf.DataType) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"value": value, "dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "HostConst",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Splits a tensor into `num_split` tensors along one dimension.
//
// Arguments:
@@ -3377,6 +3492,204 @@ func PopulationCount(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Bucketize each feature based on bucket boundaries.
+//
+// An op that returns a list of float tensors, where each tensor represents the
+// bucketized values for a single feature.
+//
+// Arguments:
+// float_values: float; List of Rank 2 Tensor each containing float values for a single feature.
+// bucket_boundaries: float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+// feature.
+//
+// Returns int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+func BoostedTreesBucketize(scope *Scope, float_values []tf.Output, bucket_boundaries []tf.Output) (buckets []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesBucketize",
+ Input: []tf.Input{
+ tf.OutputList(float_values), tf.OutputList(bucket_boundaries),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if buckets, idx, err = makeOutputList(op, idx, "buckets"); err != nil {
+ scope.UpdateErr("BoostedTreesBucketize", err)
+ return
+ }
+ return buckets
+}
+
+// BoostedTreesQuantileStreamResourceFlushAttr is an optional argument to BoostedTreesQuantileStreamResourceFlush.
+type BoostedTreesQuantileStreamResourceFlushAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceFlushGenerateQuantiles sets the optional generate_quantiles attribute to value.
+//
+// value: bool; If True, the output will be the num_quantiles for each stream where the ith
+// entry is the ith quantile of the input with an approximation error of epsilon.
+// Duplicate values may be present.
+// If False, the output will be the points in the histogram that we got which roughly
+// translates to 1/epsilon boundaries and without any duplicates.
+// Default to False.
+// If not specified, defaults to false
+func BoostedTreesQuantileStreamResourceFlushGenerateQuantiles(value bool) BoostedTreesQuantileStreamResourceFlushAttr {
+ return func(m optionalAttr) {
+ m["generate_quantiles"] = value
+ }
+}
+
+// Flush the summaries for a quantile stream resource.
+//
+// An op that flushes the summaries for a quantile stream resource.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// num_buckets: int; approximate number of buckets unless using generate_quantiles.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceFlush(scope *Scope, quantile_stream_resource_handle tf.Output, num_buckets tf.Output, optional ...BoostedTreesQuantileStreamResourceFlushAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceFlush",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, num_buckets,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Add the quantile summaries to each quantile stream resource.
+//
+// An op that adds a list of quantile summaries to a quantile stream resource. Each
+// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+// for a single feature.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceAddSummaries",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, tf.OutputList(summaries),
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Makes the summary of quantiles for the batch.
+//
+// An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+//
+// Arguments:
+// float_values: float; List of Rank 2 Tensors each containing values for a single feature.
+// example_weights: float; Rank 1 Tensor with weights per instance.
+// epsilon: float; The required maximum approximation error.
+//
+// Returns float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+// min_rank, max_rank) of a single feature.
+func BoostedTreesMakeQuantileSummaries(scope *Scope, float_values []tf.Output, example_weights tf.Output, epsilon tf.Output) (summaries []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesMakeQuantileSummaries",
+ Input: []tf.Input{
+ tf.OutputList(float_values), example_weights, epsilon,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if summaries, idx, err = makeOutputList(op, idx, "summaries"); err != nil {
+ scope.UpdateErr("BoostedTreesMakeQuantileSummaries", err)
+ return
+ }
+ return summaries
+}
+
+// BoostedTreesCreateQuantileStreamResourceAttr is an optional argument to BoostedTreesCreateQuantileStreamResource.
+type BoostedTreesCreateQuantileStreamResourceAttr func(optionalAttr)
+
+// BoostedTreesCreateQuantileStreamResourceMaxElements sets the optional max_elements attribute to value.
+//
+// value: int; The maximum number of data points that can be fed to the stream.
+// If not specified, defaults to 1099511627776
+func BoostedTreesCreateQuantileStreamResourceMaxElements(value int64) BoostedTreesCreateQuantileStreamResourceAttr {
+ return func(m optionalAttr) {
+ m["max_elements"] = value
+ }
+}
+
+// Create the Resource for Quantile Streams.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource; Handle to quantile stream resource.
+// epsilon: float; The required approximation error of the stream resource.
+// num_streams: int; The number of streams managed by the resource that shares the same epsilon.
+//
+// Returns the created operation.
+func BoostedTreesCreateQuantileStreamResource(scope *Scope, quantile_stream_resource_handle tf.Output, epsilon tf.Output, num_streams tf.Output, optional ...BoostedTreesCreateQuantileStreamResourceAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesCreateQuantileStreamResource",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, epsilon, num_streams,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Checks whether a quantile stream has been initialized.
+//
+// An Op that checks if quantile stream resource is initialized.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource; The reference to quantile stream resource handle.
+//
+// Returns bool; True if the resource is initialized, False otherwise.
+func IsBoostedTreesQuantileStreamResourceInitialized(scope *Scope, quantile_stream_resource_handle tf.Output) (is_initialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "IsBoostedTreesQuantileStreamResourceInitialized",
+ Input: []tf.Input{
+ quantile_stream_resource_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Calculates the prior from the training data (the bias) and fills in the first node with the logits' prior. Returns a boolean indicating whether to continue centering.
//
// Arguments:
@@ -3456,113 +3769,121 @@ func BoostedTreesSerializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output)
return op.Output(0), op.Output(1)
}
-// Computes the sum along sparse segments of a tensor.
-//
-// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
-// misisng, the `output` tensor at that position will be zeroed.
+// Debugging/model interpretability outputs for each example.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// It traverses all the trees and computes debug metrics for individual examples,
+// such as getting split feature ids and logits after each split along the decision
+// path used to compute directional feature contributions.
//
-// For example:
+// Arguments:
//
-// ```python
-// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+// bucketized_features: A list of rank 1 Tensors containing bucket id for each
+// feature.
+// logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
+// examples_debug_outputs_serialized.
//
-// tf.sparse_segment_sum_with_num_segments(
-// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
-// # => [[0 0 0 0]
-// # [0 0 0 0]
-// # [0 0 0 0]]
+// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
+func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"logits_dimension": logits_dimension}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesExampleDebugOutputs",
+ Input: []tf.Input{
+ tree_ensemble_handle, tf.OutputList(bucketized_features),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Makes the summary of accumulated stats for the batch.
//
-// tf.sparse_segment_sum_with_num_segments(c,
-// tf.constant([0, 1]),
-// tf.constant([0, 2],
-// num_segments=4))
-// # => [[ 1 2 3 4]
-// # [ 0 0 0 0]
-// # [-1 -2 -3 -4]
-// # [ 0 0 0 0]]
-// ```
+// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
//
// Arguments:
+// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
+// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
+// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
+// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
+// max_splits: int; the maximum number of splits possible in the whole tree.
+// num_buckets: int; equals to the maximum possible value of bucketized feature.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
-// num_segments: Should equal the number of distinct segment IDs.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
+func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
opspec := tf.OpSpec{
- Type: "SparseSegmentSumWithNumSegments",
+ Type: "BoostedTreesMakeStatsSummary",
Input: []tf.Input{
- data, indices, segment_ids, num_segments,
+ node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// PreventGradientAttr is an optional argument to PreventGradient.
-type PreventGradientAttr func(optionalAttr)
-
-// PreventGradientMessage sets the optional message attribute to value.
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
//
-// value: Will be printed in the error when anyone tries to differentiate
-// this operation.
-// If not specified, defaults to ""
-func PreventGradientMessage(value string) PreventGradientAttr {
- return func(m optionalAttr) {
- m["message"] = value
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
}
-// An identity op that triggers an error if a gradient is requested.
-//
-// When executed in a graph, this op outputs its input tensor as-is.
-//
-// When building ops to compute gradients, the TensorFlow gradient system
-// will return an error when trying to lookup the gradient of this op,
-// because no gradient must ever be registered for this function. This
-// op exists to prevent subtle bugs from silently returning unimplemented
-// gradients in some corner cases.
+// Creates a tree ensemble model and returns a handle to it.
//
// Arguments:
-// input: any tensor.
+// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
+// stamp_token: Token to use as the initial value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the tree ensemble.
//
-// Returns the same input tensor.
-func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
+// Returns the created operation.
+func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "PreventGradient",
+ Type: "BoostedTreesCreateEnsemble",
Input: []tf.Input{
- input,
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
},
- Attrs: attrs,
}
- op := scope.AddOperation(opspec)
- return op.Output(0)
+ return scope.AddOperation(opspec)
}
-// Computes asin of x element-wise.
-func Asin(scope *Scope, x tf.Output) (y tf.Output) {
+// Checks whether a tree ensemble has been initialized.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble resouce.
+//
+// Returns output boolean on whether it is initialized or not.
+func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Asin",
+ Type: "IsBoostedTreesEnsembleInitialized",
Input: []tf.Input{
- x,
+ tree_ensemble_handle,
},
}
op := scope.AddOperation(opspec)
@@ -3571,8 +3892,9 @@ func Asin(scope *Scope, x tf.Output) (y tf.Output) {
// Computes the sum along sparse segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
// dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -3638,28 +3960,32 @@ func Sinh(scope *Scope, x tf.Output) (y tf.Output) {
// Computes the minimum along segments of a tensor.
//
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
// Instead of computing the sum over segments, it computes the minimum such that:
//
-// \\(output_i = \min_j data_j\\) where min is over `j` such
-// that `segment_ids[j] == i`.
+// \\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such
+// that `segment_ids[j...] == i`.
//
// If the minimum is empty for a given segment ID `i`, it outputs the largest
// possible value for the specific numeric type,
// `output[i] = numeric_limits<T>::max()`.
//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
//
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
func UnsortedSegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
@@ -3691,11 +4017,12 @@ func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
// Computes the sum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
-// \\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+// \\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
// that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
// need not be sorted and need not cover all values in the full
// range of valid values.
@@ -4272,37 +4599,61 @@ func AddV2(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// NthElementAttr is an optional argument to NthElement.
-type NthElementAttr func(optionalAttr)
+// Computes exponential of x element-wise. \\(y = e^x\\).
+func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exp",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
-// NthElementReverse sets the optional reverse attribute to value.
+// Returns an element-wise indication of the sign of a number.
//
-// value: When set to True, find the nth-largest value in the vector and vice
-// versa.
-// If not specified, defaults to false
-func NthElementReverse(value bool) NthElementAttr {
+// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`.
+//
+// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
+func Sign(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Sign",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ArgMinAttr is an optional argument to ArgMin.
+type ArgMinAttr func(optionalAttr)
+
+// ArgMinOutputType sets the optional output_type attribute to value.
+// If not specified, defaults to DT_INT64
+func ArgMinOutputType(value tf.DataType) ArgMinAttr {
return func(m optionalAttr) {
- m["reverse"] = value
+ m["output_type"] = value
}
}
-// Finds values of the `n`-th order statistic for the last dimension.
-//
-// If the input is a vector (rank-1), finds the entries which is the nth-smallest
-// value in the vector and outputs their values as scalar tensor.
-//
-// For matrices (resp. higher rank input), computes the entries which is the
-// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+// Returns the index with the smallest value across dimensions of a tensor.
//
-// values.shape = input.shape[:-1]
+// Note that in case of ties the identity of the return value is not guaranteed.
//
// Arguments:
-// input: 1-D or higher with last dimension at least `n+1`.
-// n: 0-D. Position of sorted vector to select along the last dimension (along
-// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
//
-// Returns The `n`-th order statistic along each last dimensional slice.
-func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
+// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
+// Describes which dimension of the input Tensor to reduce across. For vectors,
+// use dimension = 0.
+func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -4311,9 +4662,9 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme
a(attrs)
}
opspec := tf.OpSpec{
- Type: "NthElement",
+ Type: "ArgMin",
Input: []tf.Input{
- input, n,
+ input, dimension,
},
Attrs: attrs,
}
@@ -4321,55 +4672,94 @@ func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthEleme
return op.Output(0)
}
-// Computes the maximum along segments of a tensor.
+// Convert the quantized 'input' tensor into a lower-precision 'output', using the
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// output range specified with 'requested_output_min' and 'requested_output_max'.
//
-// This operator is similar to the unsorted segment sum operator found
-// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
-// Instead of computing the sum over segments, it computes the maximum such that:
+// [input_min, input_max] are scalar floats that specify the range for the float
+// interpretation of the 'input' data. For example, if input_min is -1.0f and
+// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0
+// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f.
//
-// \\(output_i = \max_j data_j\\) where max is over `j` such
-// that `segment_ids[j] == i`.
+// Arguments:
//
-// If the maximum is empty for a given segment ID `i`, it outputs the smallest
-// possible value for the specific numeric type,
-// `output[i] = numeric_limits<T>::lowest()`.
+// input_min: The float value that the minimum quantized input value represents.
+// input_max: The float value that the maximum quantized input value represents.
+// requested_output_min: The float value that the minimum quantized output value represents.
+// requested_output_max: The float value that the maximum quantized output value represents.
+// out_type: The type of the output. Should be a lower bit depth than Tinput.
//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
-// </div>
+// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output.
+func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"out_type": out_type}
+ opspec := tf.OpSpec{
+ Type: "Requantize",
+ Input: []tf.Input{
+ input, input_min, input_max, requested_output_min, requested_output_max,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// PreventGradientAttr is an optional argument to PreventGradient.
+type PreventGradientAttr func(optionalAttr)
+
+// PreventGradientMessage sets the optional message attribute to value.
//
-// Arguments:
+// value: Will be printed in the error when anyone tries to differentiate
+// this operation.
+// If not specified, defaults to ""
+func PreventGradientMessage(value string) PreventGradientAttr {
+ return func(m optionalAttr) {
+ m["message"] = value
+ }
+}
+
+// An identity op that triggers an error if a gradient is requested.
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+// When executed in a graph, this op outputs its input tensor as-is.
//
+// When building ops to compute gradients, the TensorFlow gradient system
+// will return an error when trying to lookup the gradient of this op,
+// because no gradient must ever be registered for this function. This
+// op exists to prevent subtle bugs from silently returning unimplemented
+// gradients in some corner cases.
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
-func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
+// Arguments:
+// input: any tensor.
+//
+// Returns the same input tensor.
+func PreventGradient(scope *Scope, input tf.Output, optional ...PreventGradientAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "UnsortedSegmentMax",
+ Type: "PreventGradient",
Input: []tf.Input{
- data, segment_ids, num_segments,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// Computes exponential of x element-wise. \\(y = e^x\\).
-func Exp(scope *Scope, x tf.Output) (y tf.Output) {
+// Computes asin of x element-wise.
+func Asin(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Exp",
+ Type: "Asin",
Input: []tf.Input{
x,
},
@@ -4378,46 +4768,86 @@ func Exp(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
-// Returns an element-wise indication of the sign of a number.
+// Computes the maximum along segments of a tensor.
//
-// `y = sign(x) = -1` if `x < 0`; 0 if `x == 0`; 1 if `x > 0`.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
-// For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
-func Sign(scope *Scope, x tf.Output) (y tf.Output) {
+// This operator is similar to the unsorted segment sum operator found
+// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+// Instead of computing the sum over segments, it computes the maximum such that:
+//
+// \\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+// that `segment_ids[j...] == i`.
+//
+// If the maximum is empty for a given segment ID `i`, it outputs the smallest
+// possible value for the specific numeric type,
+// `output[i] = numeric_limits<T>::lowest()`.
+//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.END
+// }
+// out_arg {
+// name: "output"
+// description: <<END
+// Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
+//
+func UnsortedSegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Sign",
+ Type: "UnsortedSegmentMax",
Input: []tf.Input{
- x,
+ data, segment_ids, num_segments,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// ArgMinAttr is an optional argument to ArgMin.
-type ArgMinAttr func(optionalAttr)
+// NthElementAttr is an optional argument to NthElement.
+type NthElementAttr func(optionalAttr)
-// ArgMinOutputType sets the optional output_type attribute to value.
-// If not specified, defaults to DT_INT64
-func ArgMinOutputType(value tf.DataType) ArgMinAttr {
+// NthElementReverse sets the optional reverse attribute to value.
+//
+// value: When set to True, find the nth-largest value in the vector and vice
+// versa.
+// If not specified, defaults to false
+func NthElementReverse(value bool) NthElementAttr {
return func(m optionalAttr) {
- m["output_type"] = value
+ m["reverse"] = value
}
}
-// Returns the index with the smallest value across dimensions of a tensor.
+// Finds values of the `n`-th order statistic for the last dimension.
//
-// Note that in case of ties the identity of the return value is not guaranteed.
+// If the input is a vector (rank-1), finds the entries which is the nth-smallest
+// value in the vector and outputs their values as scalar tensor.
+//
+// For matrices (resp. higher rank input), computes the entries which is the
+// nth-smallest value in each row (resp. vector along the last dimension). Thus,
+//
+// values.shape = input.shape[:-1]
//
// Arguments:
+// input: 1-D or higher with last dimension at least `n+1`.
+// n: 0-D. Position of sorted vector to select along the last dimension (along
+// each row for matrices). Valid range of n is `[0, input.shape[:-1])`
//
-// dimension: int32 or int64, must be in the range `[-rank(input), rank(input))`.
-// Describes which dimension of the input Tensor to reduce across. For vectors,
-// use dimension = 0.
-func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) {
+// Returns The `n`-th order statistic along each last dimensional slice.
+func NthElement(scope *Scope, input tf.Output, n tf.Output, optional ...NthElementAttr) (values tf.Output) {
if scope.Err() != nil {
return
}
@@ -4426,9 +4856,9 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ArgMin",
+ Type: "NthElement",
Input: []tf.Input{
- input, dimension,
+ input, n,
},
Attrs: attrs,
}
@@ -4436,38 +4866,56 @@ func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgM
return op.Output(0)
}
-// Convert the quantized 'input' tensor into a lower-precision 'output', using the
+// Computes the sum along sparse segments of a tensor.
//
-// output range specified with 'requested_output_min' and 'requested_output_max'.
+// Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+// misisng, the `output` tensor at that position will be zeroed.
//
-// [input_min, input_max] are scalar floats that specify the range for the float
-// interpretation of the 'input' data. For example, if input_min is -1.0f and
-// input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0
-// value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// For example:
+//
+// ```python
+// c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+//
+// tf.sparse_segment_sum_with_num_segments(
+// c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+// # => [[0 0 0 0]
+// # [0 0 0 0]
+// # [0 0 0 0]]
+//
+// tf.sparse_segment_sum_with_num_segments(c,
+// tf.constant([0, 1]),
+// tf.constant([0, 2],
+// num_segments=4))
+// # => [[ 1 2 3 4]
+// # [ 0 0 0 0]
+// # [-1 -2 -3 -4]
+// # [ 0 0 0 0]]
+// ```
//
// Arguments:
//
-// input_min: The float value that the minimum quantized input value represents.
-// input_max: The float value that the maximum quantized input value represents.
-// requested_output_min: The float value that the minimum quantized output value represents.
-// requested_output_max: The float value that the maximum quantized output value represents.
-// out_type: The type of the output. Should be a lower bit depth than Tinput.
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// num_segments: Should equal the number of distinct segment IDs.
//
-// Returns The requested_output_min value is copied into this output.The requested_output_max value is copied into this output.
-func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf.Output, requested_output_min tf.Output, requested_output_max tf.Output, out_type tf.DataType) (output tf.Output, output_min tf.Output, output_max tf.Output) {
+// Returns Has same shape as data, except for dimension 0 which
+// has size `num_segments`.
+func SparseSegmentSumWithNumSegments(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"out_type": out_type}
opspec := tf.OpSpec{
- Type: "Requantize",
+ Type: "SparseSegmentSumWithNumSegments",
Input: []tf.Input{
- input, input_min, input_max, requested_output_min, requested_output_max,
+ data, indices, segment_ids, num_segments,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
// Computes the determinant of one or more square matrices.
@@ -5195,6 +5643,47 @@ func Reciprocal(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features.
+//
+// Arguments:
+//
+//
+// dense_defaults: A dict mapping string keys to `Tensor`s.
+// The keys of the dict must match the dense_keys of the feature.
+// sparse_keys: A list of string keys in the examples features.
+// The results for these keys will be returned as `SparseTensor` objects.
+// dense_keys: A list of Ndense string Tensors (scalars).
+// The keys expected in the Examples features associated with dense values.
+// sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+// Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+// and `tf.string` (`BytesList`) are supported.
+// dense_shapes: List of tuples with the same length as `dense_keys`.
+// The shape of the data for each dense feature referenced by `dense_keys`.
+// Required for any input tensors identified by `dense_keys`. Must be
+// either fully defined, or may contain an unknown first dimension.
+// An unknown first dimension means the feature is treated as having
+// a variable number of blocks, and the output shape along this dimension
+// is considered unknown at graph build time. Padding is applied for
+// minibatch elements smaller than the maximum number of blocks for the
+// given feature along this dimension.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+func ParseExampleDataset(scope *Scope, input_dataset tf.Output, num_parallel_calls tf.Output, dense_defaults []tf.Output, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes, "output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ParseExampleDataset",
+ Input: []tf.Input{
+ input_dataset, num_parallel_calls, tf.OutputList(dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns a batched matrix tensor with new batched diagonal values.
//
// Given `input` and `diagonal`, this operation returns a tensor with the
@@ -5386,26 +5875,6 @@ func LogicalAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
return op.Output(0)
}
-// Checks whether a tree ensemble has been initialized.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resouce.
-//
-// Returns output boolean on whether it is initialized or not.
-func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "IsBoostedTreesEnsembleInitialized",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// CastAttr is an optional argument to Cast.
type CastAttr func(optionalAttr)
@@ -5589,6 +6058,44 @@ func QuantizedAvgPool(scope *Scope, input tf.Output, min_input tf.Output, max_in
return op.Output(0), op.Output(1), op.Output(2)
}
+// Extract `patches` from `input` and put them in the "depth" output
+// dimension. 3D extension of `extract_image_patches`.
+//
+// Arguments:
+// input: 5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+// ksizes: The size of the sliding window for each dimension of `input`.
+// strides: 1-D of length 5. How far the centers of two consecutive patches are in
+// `input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+// padding: The type of padding algorithm to use.
+//
+// We specify the size-related attributes as:
+//
+// ```python
+// ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+// strides = [1, stride_planes, strides_rows, strides_cols, 1]
+// ```
+//
+// Returns 5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+// ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+// with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+// in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+// are the dimensions of the output patches.
+func ExtractVolumePatches(scope *Scope, input tf.Output, ksizes []int64, strides []int64, padding string) (patches tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"ksizes": ksizes, "strides": strides, "padding": padding}
+ opspec := tf.OpSpec{
+ Type: "ExtractVolumePatches",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// FractionalAvgPoolAttr is an optional argument to FractionalAvgPool.
type FractionalAvgPoolAttr func(optionalAttr)
@@ -6159,6 +6666,41 @@ func ResourceSparseApplyAdadelta(scope *Scope, var_ tf.Output, accum tf.Output,
return scope.AddOperation(opspec)
}
+// Gets next element for the provided shard number.
+//
+// Arguments:
+// multi_device_iterator: A MultiDeviceIterator resource.
+// shard_num: Integer representing which shard to fetch data for.
+// incarnation_id: Which incarnation of the MultiDeviceIterator is running.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+//
+// Returns Result of the get_next on the dataset.
+func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.Output, shard_num tf.Output, incarnation_id tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorGetNextFromShard",
+ Input: []tf.Input{
+ multi_device_iterator, shard_num, incarnation_id,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+ scope.UpdateErr("MultiDeviceIteratorGetNextFromShard", err)
+ return
+ }
+ return components
+}
+
// Computes rectified linear gradients for a Relu operation.
//
// Arguments:
@@ -6446,7 +6988,7 @@ func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset
return offset
}
-// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
+// Compute the lower regularized incomplete Gamma function `P(a, x)`.
//
// The lower regularized incomplete Gamma function is defined as:
//
@@ -7880,6 +8422,214 @@ func QueueDequeueV2(scope *Scope, handle tf.Output, component_types []tf.DataTyp
return components
}
+// ParseSequenceExampleAttr is an optional argument to ParseSequenceExample.
+type ParseSequenceExampleAttr func(optionalAttr)
+
+// ParseSequenceExampleNcontextSparse sets the optional Ncontext_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextSparse(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Ncontext_sparse"] = value
+ }
+}
+
+// ParseSequenceExampleNcontextDense sets the optional Ncontext_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNcontextDense(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Ncontext_dense"] = value
+ }
+}
+
+// ParseSequenceExampleNfeatureListSparse sets the optional Nfeature_list_sparse attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListSparse(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Nfeature_list_sparse"] = value
+ }
+}
+
+// ParseSequenceExampleNfeatureListDense sets the optional Nfeature_list_dense attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func ParseSequenceExampleNfeatureListDense(value int64) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["Nfeature_list_dense"] = value
+ }
+}
+
+// ParseSequenceExampleContextSparseTypes sets the optional context_sparse_types attribute to value.
+//
+// value: A list of Ncontext_sparse types; the data types of data in
+// each context Feature given in context_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_sparse_types"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListDenseTypes sets the optional feature_list_dense_types attribute to value.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_types"] = value
+ }
+}
+
+// ParseSequenceExampleContextDenseShapes sets the optional context_dense_shapes attribute to value.
+//
+// value: A list of Ncontext_dense shapes; the shapes of data in
+// each context Feature given in context_dense_keys.
+// The number of elements in the Feature corresponding to context_dense_key[j]
+// must always equal context_dense_shapes[j].NumEntries().
+// The shape of context_dense_values[j] will match context_dense_shapes[j].
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleContextDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["context_dense_shapes"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListSparseTypes sets the optional feature_list_sparse_types attribute to value.
+//
+// value: A list of Nfeature_list_sparse types; the data types
+// of data in each FeatureList given in feature_list_sparse_keys.
+// Currently the ParseSingleSequenceExample supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListSparseTypes(value []tf.DataType) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_sparse_types"] = value
+ }
+}
+
+// ParseSequenceExampleFeatureListDenseShapes sets the optional feature_list_dense_shapes attribute to value.
+//
+// value: A list of Nfeature_list_dense shapes; the shapes of
+// data in each FeatureList given in feature_list_dense_keys.
+// The shape of each Feature in the FeatureList corresponding to
+// feature_list_dense_key[j] must always equal
+// feature_list_dense_shapes[j].NumEntries().
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func ParseSequenceExampleFeatureListDenseShapes(value []tf.Shape) ParseSequenceExampleAttr {
+ return func(m optionalAttr) {
+ m["feature_list_dense_shapes"] = value
+ }
+}
+
+// Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing binary serialized SequenceExample protos.
+// debug_name: A vector containing the names of the serialized protos.
+// May contain, for example, table key (descriptive) name for the
+// corresponding serialized proto. This is purely useful for debugging
+// purposes, and the presence of values here has no effect on the output.
+// May also be an empty vector if no name is available.
+// context_dense_defaults: A list of Ncontext_dense Tensors (some may be empty).
+// context_dense_defaults[j] provides default values
+// when the SequenceExample's context map lacks context_dense_key[j].
+// If an empty Tensor is provided for context_dense_defaults[j],
+// then the Feature context_dense_keys[j] is required.
+// The input type is inferred from context_dense_defaults[j], even when it's
+// empty. If context_dense_defaults[j] is not empty, its shape must match
+// context_dense_shapes[j].
+// feature_list_dense_missing_assumed_empty: A vector listing the
+// FeatureList keys which may be missing from the SequenceExamples. If the
+// associated FeatureList is missing, it is treated as empty. By default,
+// any FeatureList not listed in this vector must exist in the SequenceExamples.
+// context_sparse_keys: A list of Ncontext_sparse string Tensors (scalars).
+// The keys expected in the Examples' features associated with context_sparse
+// values.
+// context_dense_keys: A list of Ncontext_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' context features associated with
+// dense values.
+// feature_list_sparse_keys: A list of Nfeature_list_sparse string Tensors
+// (scalars). The keys expected in the FeatureLists associated with sparse
+// values.
+// feature_list_dense_keys: A list of Nfeature_list_dense string Tensors (scalars).
+// The keys expected in the SequenceExamples' feature_lists associated
+// with lists of dense values.
+func ParseSequenceExample(scope *Scope, serialized tf.Output, debug_name tf.Output, context_dense_defaults []tf.Output, feature_list_dense_missing_assumed_empty []string, context_sparse_keys []string, context_dense_keys []string, feature_list_sparse_keys []string, feature_list_dense_keys []string, optional ...ParseSequenceExampleAttr) (context_sparse_indices []tf.Output, context_sparse_values []tf.Output, context_sparse_shapes []tf.Output, context_dense_values []tf.Output, feature_list_sparse_indices []tf.Output, feature_list_sparse_values []tf.Output, feature_list_sparse_shapes []tf.Output, feature_list_dense_values []tf.Output, feature_list_dense_lengths []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"feature_list_dense_missing_assumed_empty": feature_list_dense_missing_assumed_empty, "context_sparse_keys": context_sparse_keys, "context_dense_keys": context_dense_keys, "feature_list_sparse_keys": feature_list_sparse_keys, "feature_list_dense_keys": feature_list_dense_keys}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ParseSequenceExample",
+ Input: []tf.Input{
+ serialized, debug_name, tf.OutputList(context_dense_defaults),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if context_sparse_indices, idx, err = makeOutputList(op, idx, "context_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_sparse_values, idx, err = makeOutputList(op, idx, "context_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_sparse_shapes, idx, err = makeOutputList(op, idx, "context_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if context_dense_values, idx, err = makeOutputList(op, idx, "context_dense_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_indices, idx, err = makeOutputList(op, idx, "feature_list_sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_values, idx, err = makeOutputList(op, idx, "feature_list_sparse_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_sparse_shapes, idx, err = makeOutputList(op, idx, "feature_list_sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_dense_values, idx, err = makeOutputList(op, idx, "feature_list_dense_values"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ if feature_list_dense_lengths, idx, err = makeOutputList(op, idx, "feature_list_dense_lengths"); err != nil {
+ scope.UpdateErr("ParseSequenceExample", err)
+ return
+ }
+ return context_sparse_indices, context_sparse_values, context_sparse_shapes, context_dense_values, feature_list_sparse_indices, feature_list_sparse_values, feature_list_sparse_shapes, feature_list_dense_values, feature_list_dense_lengths
+}
+
// Computes the Gauss error function of `x` element-wise.
func Erf(scope *Scope, x tf.Output) (y tf.Output) {
if scope.Err() != nil {
@@ -8681,6 +9431,66 @@ func RandomStandardNormal(scope *Scope, shape tf.Output, dtype tf.DataType, opti
return op.Output(0)
}
+// RandomUniformIntAttr is an optional argument to RandomUniformInt.
+type RandomUniformIntAttr func(optionalAttr)
+
+// RandomUniformIntSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random integers from a uniform distribution.
+//
+// The generated values are uniform integers in the range `[minval, maxval)`.
+// The lower bound `minval` is included in the range, while the upper bound
+// `maxval` is excluded.
+//
+// The random integers are slightly biased unless `maxval - minval` is an exact
+// power of two. The bias is small for values of `maxval - minval` significantly
+// smaller than the range of the output (either `2^32` or `2^64`).
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// minval: 0-D. Inclusive lower bound on the generated integers.
+// maxval: 0-D. Exclusive upper bound on the generated integers.
+//
+// Returns A tensor of the specified shape filled with uniform random integers.
+func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomUniformInt",
+ Input: []tf.Input{
+ shape, minval, maxval,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceApplyFtrlAttr is an optional argument to ResourceApplyFtrl.
type ResourceApplyFtrlAttr func(optionalAttr)
@@ -9113,6 +9923,29 @@ func DataFormatVecPermute(scope *Scope, x tf.Output, optional ...DataFormatVecPe
return op.Output(0)
}
+// Initializes the multi device iterator with the given dataset.
+//
+// Arguments:
+// dataset: Dataset to be iterated upon.
+// multi_device_iterator: A MultiDeviceIteratorResource.
+// max_buffer_size: The maximum size of the host side per device buffer to keep.
+//
+// Returns An int64 indicating which incarnation of the MultiDeviceIterator
+// is running.
+func MultiDeviceIteratorInit(scope *Scope, dataset tf.Output, multi_device_iterator tf.Output, max_buffer_size tf.Output) (incarnation_id tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorInit",
+ Input: []tf.Input{
+ dataset, multi_device_iterator, max_buffer_size,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient of `igamma(a, x)` wrt `a`.
func IgammaGradA(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
if scope.Err() != nil {
@@ -9158,6 +9991,49 @@ func StringToHashBucket(scope *Scope, string_tensor tf.Output, num_buckets int64
return op.Output(0)
}
+// StaticRegexReplaceAttr is an optional argument to StaticRegexReplace.
+type StaticRegexReplaceAttr func(optionalAttr)
+
+// StaticRegexReplaceReplaceGlobal sets the optional replace_global attribute to value.
+//
+// value: If True, the replacement is global, otherwise the replacement
+// is done only on the first match.
+// If not specified, defaults to true
+func StaticRegexReplaceReplaceGlobal(value bool) StaticRegexReplaceAttr {
+ return func(m optionalAttr) {
+ m["replace_global"] = value
+ }
+}
+
+// Replaces the match of pattern in input with rewrite.
+//
+// It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+// input: The text to be processed.
+// pattern: The regular expression to match the input.
+// rewrite: The rewrite to be applied to the matched expresion.
+//
+// Returns The text after applying pattern and rewrite.
+func StaticRegexReplace(scope *Scope, input tf.Output, pattern string, rewrite string, optional ...StaticRegexReplaceAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"pattern": pattern, "rewrite": rewrite}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StaticRegexReplace",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes gradients for the exponential linear (Elu) operation.
//
// Arguments:
@@ -10024,7 +10900,7 @@ func ResourceScatterNdAddUseLocking(value bool) ResourceScatterNdAddAttr {
//
// [1, 12, 3, 14, 14, 6, 7, 20]
//
-// See @{tf.scatter_nd} for more details about how to make updates to
+// See `tf.scatter_nd` for more details about how to make updates to
// slices.
//
// Arguments:
@@ -11335,36 +12211,43 @@ func FixedLengthRecordReaderV2(scope *Scope, record_bytes int64, optional ...Fix
return op.Output(0)
}
-// The gradient operator for the SparseAdd op.
+// StringLengthAttr is an optional argument to StringLength.
+type StringLengthAttr func(optionalAttr)
+
+// StringLengthUnit sets the optional unit attribute to value.
+// If not specified, defaults to "BYTE"
+func StringLengthUnit(value string) StringLengthAttr {
+ return func(m optionalAttr) {
+ m["unit"] = value
+ }
+}
+
+// String lengths of `input`.
//
-// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
-// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
-// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
-// values of A and B.
+// Computes the length of each string given in the input tensor.
//
// Arguments:
-// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
-// the non-empty values of the sum.
-// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
-// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
-// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
-// `[nnz(sum), ndims]`.
+// input: The string for which to compute the length.
//
-// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
-// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
-// non-empty values of B.
-func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+// Returns Integer tensor that has the same shape as `input`. The output contains the
+// element-wise string lengths of `input`.
+func StringLength(scope *Scope, input tf.Output, optional ...StringLengthAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseAddGrad",
+ Type: "StringLength",
Input: []tf.Input{
- backprop_val_grad, a_indices, b_indices, sum_indices,
+ input,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
+ return op.Output(0)
}
// Converts each string in the input Tensor to its hash mod by a number of buckets.
@@ -11717,7 +12600,7 @@ func ResourceScatterNdUpdateUseLocking(value bool) ResourceScatterNdUpdateAttr {
//
// [1, 11, 3, 10, 9, 6, 7, 12]
//
-// See @{tf.scatter_nd} for more details about how to make updates to
+// See `tf.scatter_nd` for more details about how to make updates to
// slices.
//
// Arguments:
@@ -11746,6 +12629,26 @@ func ResourceScatterNdUpdate(scope *Scope, ref tf.Output, indices tf.Output, upd
return scope.AddOperation(opspec)
}
+// Produces a string handle for the given MultiDeviceIterator.
+//
+// Arguments:
+// multi_device_iterator: A MultiDeviceIterator resource.
+//
+// Returns A string representing the resource.
+func MultiDeviceIteratorToStringHandle(scope *Scope, multi_device_iterator tf.Output) (string_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorToStringHandle",
+ Input: []tf.Input{
+ multi_device_iterator,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Applies softmax to a batched N-D `SparseTensor`.
//
// The inputs represent an N-D SparseTensor with logical shape `[..., B, C]`
@@ -12200,10 +13103,188 @@ func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
return op.Output(0)
}
+// StringFormatAttr is an optional argument to StringFormat.
+type StringFormatAttr func(optionalAttr)
+
+// StringFormatTemplate sets the optional template attribute to value.
+//
+// value: A string, the template to format tensor summaries into.
+// If not specified, defaults to "%s"
+func StringFormatTemplate(value string) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["template"] = value
+ }
+}
+
+// StringFormatPlaceholder sets the optional placeholder attribute to value.
+//
+// value: A string, at each placeholder in the template a subsequent tensor summary will be inserted.
+// If not specified, defaults to "%s"
+func StringFormatPlaceholder(value string) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["placeholder"] = value
+ }
+}
+
+// StringFormatSummarize sets the optional summarize attribute to value.
+//
+// value: When formatting the tensor summaries print the first and last summarize entries of each tensor dimension.
+// If not specified, defaults to 3
+func StringFormatSummarize(value int64) StringFormatAttr {
+ return func(m optionalAttr) {
+ m["summarize"] = value
+ }
+}
+
+// Formats a string template using a list of tensors.
+//
+// Formats a string template using a list of tensors, pretty-printing tensor summaries.
+//
+// Arguments:
+// inputs: The list of tensors to format into the placeholder string.
+//
+// Returns = The resulting string scalar.
+func StringFormat(scope *Scope, inputs []tf.Output, optional ...StringFormatAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringFormat",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ShapeAttr is an optional argument to Shape.
+type ShapeAttr func(optionalAttr)
+
+// ShapeOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func ShapeOutType(value tf.DataType) ShapeAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns the shape of a tensor.
+//
+// This operation returns a 1-D integer tensor representing the shape of `input`.
+//
+// For example:
+//
+// ```
+// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
+// shape(t) ==> [2, 2, 3]
+// ```
+func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Shape",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes the power of one value to another.
+//
+// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
+// corresponding elements in `x` and `y`. For example:
+//
+// ```
+// # tensor 'x' is [[2, 2]], [3, 3]]
+// # tensor 'y' is [[8, 16], [2, 3]]
+// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
+// ```
+func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Pow",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes fingerprints of the input strings.
+//
+// Arguments:
+// input: vector of strings to compute fingerprints on.
+//
+// Returns a (N,2) shaped matrix where N is the number of elements in the input
+// vector. Each row contains the low and high parts of the fingerprint.
+func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SdcaFprint",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// The gradient operator for the SparseAdd op.
+//
+// The SparseAdd op calculates A + B, where A, B, and the sum are all represented
+// as `SparseTensor` objects. This op takes in the upstream gradient w.r.t.
+// non-empty values of the sum, and outputs the gradients w.r.t. the non-empty
+// values of A and B.
+//
+// Arguments:
+// backprop_val_grad: 1-D with shape `[nnz(sum)]`. The gradient with respect to
+// the non-empty values of the sum.
+// a_indices: 2-D. The `indices` of the `SparseTensor` A, size `[nnz(A), ndims]`.
+// b_indices: 2-D. The `indices` of the `SparseTensor` B, size `[nnz(B), ndims]`.
+// sum_indices: 2-D. The `indices` of the sum `SparseTensor`, size
+// `[nnz(sum), ndims]`.
+//
+// Returns 1-D with shape `[nnz(A)]`. The gradient with respect to the
+// non-empty values of A.1-D with shape `[nnz(B)]`. The gradient with respect to the
+// non-empty values of B.
+func SparseAddGrad(scope *Scope, backprop_val_grad tf.Output, a_indices tf.Output, b_indices tf.Output, sum_indices tf.Output) (a_val_grad tf.Output, b_val_grad tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseAddGrad",
+ Input: []tf.Input{
+ backprop_val_grad, a_indices, b_indices, sum_indices,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// Computes the mean along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
@@ -12218,7 +13299,7 @@ func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -12337,7 +13418,7 @@ func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, o
//
// Arguments:
// input: A string tensor of the text to be processed.
-// pattern: A 1-D string tensor of the regular expression to match the input.
+// pattern: A scalar string tensor containing the regular expression to match the input.
//
// Returns A bool tensor with the same shape as `input`.
func RegexFullMatch(scope *Scope, input tf.Output, pattern tf.Output) (output tf.Output) {
@@ -12391,6 +13472,79 @@ func InTopKV2(scope *Scope, predictions tf.Output, targets tf.Output, k tf.Outpu
return op.Output(0)
}
+// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
+type RandomPoissonV2Attr func(optionalAttr)
+
+// RandomPoissonV2Seed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// RandomPoissonV2Dtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT64
+func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Outputs random values from the Poisson distribution(s) described by rate.
+//
+// This op uses two algorithms, depending on rate. If rate >= 10, then
+// the algorithm by Hormann is used to acquire samples via
+// transformation-rejection.
+// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
+//
+// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
+// random variables.
+// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
+// Programming, Volume 2. Addison Wesley
+//
+// Arguments:
+// shape: 1-D integer tensor. Shape of independent samples to draw from each
+// distribution described by the shape parameters given in rate.
+// rate: A tensor in which each scalar is a "rate" parameter describing the
+// associated poisson distribution.
+//
+// Returns A tensor with shape `shape + shape(rate)`. Each slice
+// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
+// `rate[i0, i1, ...iN]`.
+func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "RandomPoissonV2",
+ Input: []tf.Input{
+ shape, rate,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
type DecodeAndCropJpegAttr func(optionalAttr)
@@ -13892,34 +15046,6 @@ func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, label
return op.Output(0), op.Output(1)
}
-// Fast Fourier transform.
-//
-// Computes the 1-dimensional discrete Fourier transform over the inner-most
-// dimension of `input`.
-//
-// Arguments:
-// input: A complex64 tensor.
-//
-// Returns A complex64 tensor of the same shape as `input`. The inner-most
-// dimension of `input` is replaced with its 1D Fourier transform.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.fft
-// @end_compatibility
-func FFT(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "FFT",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Transforms a serialized tensorflow.TensorProto proto into a Tensor.
//
// Arguments:
@@ -14441,6 +15567,25 @@ func ResourceApplyProximalGradientDescent(scope *Scope, var_ tf.Output, alpha tf
return scope.AddOperation(opspec)
}
+// Returns 0 if the denominator is zero.
+//
+//
+// *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func DivNoNan(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DivNoNan",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient for the sqrt of `x` wrt its input.
//
// Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy`
@@ -15348,6 +16493,36 @@ func BytesProducedStatsDataset(scope *Scope, input_dataset tf.Output, tag tf.Out
return op.Output(0)
}
+// Check if the input matches the regex pattern.
+//
+// The input is a string tensor of any shape. The pattern is the
+// regular expression to be matched with every element of the input tensor.
+// The boolean values (True or False) of the output tensor indicate
+// if the input matches the regex pattern provided.
+//
+// The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+//
+// Arguments:
+// input: A string tensor of the text to be processed.
+// pattern: The regular expression to match the input.
+//
+// Returns A bool tensor with the same shape as `input`.
+func StaticRegexFullMatch(scope *Scope, input tf.Output, pattern string) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"pattern": pattern}
+ opspec := tf.OpSpec{
+ Type: "StaticRegexFullMatch",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
@@ -15847,6 +17022,64 @@ func CudnnRNNBackprop(scope *Scope, input tf.Output, input_h tf.Output, input_c
return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
}
+// UpperBoundAttr is an optional argument to UpperBound.
+type UpperBoundAttr func(optionalAttr)
+
+// UpperBoundOutType sets the optional out_type attribute to value.
+// If not specified, defaults to DT_INT32
+func UpperBoundOutType(value tf.DataType) UpperBoundAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Applies upper_bound(sorted_search_values, values) along each row.
+//
+// Each set of rows with the same index in (sorted_inputs, values) is treated
+// independently. The resulting row is the equivalent of calling
+// `np.searchsorted(sorted_inputs, values, side='right')`.
+//
+// The result is not a global index to the entire
+// `Tensor`, but rather just the index in the last dimension.
+//
+// A 2-D example:
+// sorted_sequence = [[0, 3, 9, 9, 10],
+// [1, 2, 3, 4, 5]]
+// values = [[2, 4, 9],
+// [0, 2, 6]]
+//
+// result = UpperBound(sorted_sequence, values)
+//
+// result == [[1, 2, 4],
+// [0, 2, 5]]
+//
+// Arguments:
+// sorted_inputs: 2-D Tensor where each row is ordered.
+// values: 2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains
+// the values that will be searched for in `sorted_search_values`.
+//
+// Returns A `Tensor` with the same shape as `values`. It contains the last scalar index
+// into the last dimension where values can be inserted without changing the
+// ordered property.
+func UpperBound(scope *Scope, sorted_inputs tf.Output, values tf.Output, optional ...UpperBoundAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UpperBound",
+ Input: []tf.Input{
+ sorted_inputs, values,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// FractionalMaxPoolGradAttr is an optional argument to FractionalMaxPoolGrad.
type FractionalMaxPoolGradAttr func(optionalAttr)
@@ -15945,6 +17178,23 @@ func ResourceApplyAdagradDA(scope *Scope, var_ tf.Output, gradient_accumulator t
return scope.AddOperation(opspec)
}
+// Creates a dataset containing elements of first component of `input_dataset` having true in the last component.
+func FilterByLastComponentDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "FilterByLastComponentDataset",
+ Input: []tf.Input{
+ input_dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// CudnnRNNCanonicalToParamsAttr is an optional argument to CudnnRNNCanonicalToParams.
type CudnnRNNCanonicalToParamsAttr func(optionalAttr)
@@ -16804,7 +18054,8 @@ func DecodeCSVSelectCols(value []int64) DecodeCSVAttr {
// records: Each string is a record/row in the csv and all records should have
// the same format.
// record_defaults: One tensor per column of the input record, with either a
-// scalar default value for that column or empty if the column is required.
+// scalar default value for that column or an empty vector if the column is
+// required.
//
// Returns Each tensor will have the same shape as records.
func DecodeCSV(scope *Scope, records tf.Output, record_defaults []tf.Output, optional ...DecodeCSVAttr) (output []tf.Output) {
@@ -17571,8 +18822,9 @@ func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_
// Computes the sum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \sum_j data_j\\) where sum is over `j` such
@@ -17586,7 +18838,7 @@ func ReaderNumRecordsProducedV2(scope *Scope, reader_handle tf.Output) (records_
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -18812,27 +20064,6 @@ func OptimizeDataset(scope *Scope, input_dataset tf.Output, optimizations tf.Out
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Returns the element-wise min of two SparseTensors.
//
// Assumes the two SparseTensors have the same shape, i.e., no broadcasting.
@@ -19503,8 +20734,9 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
// Computes the minimum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
@@ -19518,7 +20750,7 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -19632,164 +20864,6 @@ func SdcaOptimizer(scope *Scope, sparse_example_indices []tf.Output, sparse_feat
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
-// ShapeAttr is an optional argument to Shape.
-type ShapeAttr func(optionalAttr)
-
-// ShapeOutType sets the optional out_type attribute to value.
-// If not specified, defaults to DT_INT32
-func ShapeOutType(value tf.DataType) ShapeAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Returns the shape of a tensor.
-//
-// This operation returns a 1-D integer tensor representing the shape of `input`.
-//
-// For example:
-//
-// ```
-// # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]]
-// shape(t) ==> [2, 2, 3]
-// ```
-func Shape(scope *Scope, input tf.Output, optional ...ShapeAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Shape",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes the power of one value to another.
-//
-// Given a tensor `x` and a tensor `y`, this operation computes \\(x^y\\) for
-// corresponding elements in `x` and `y`. For example:
-//
-// ```
-// # tensor 'x' is [[2, 2]], [3, 3]]
-// # tensor 'y' is [[8, 16], [2, 3]]
-// tf.pow(x, y) ==> [[256, 65536], [9, 27]]
-// ```
-func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pow",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes fingerprints of the input strings.
-//
-// Arguments:
-// input: vector of strings to compute fingerprints on.
-//
-// Returns a (N,2) shaped matrix where N is the number of elements in the input
-// vector. Each row contains the low and high parts of the fingerprint.
-func SdcaFprint(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SdcaFprint",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// RandomPoissonV2Attr is an optional argument to RandomPoissonV2.
-type RandomPoissonV2Attr func(optionalAttr)
-
-// RandomPoissonV2Seed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomPoissonV2Seed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomPoissonV2Seed2(value int64) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// RandomPoissonV2Dtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT64
-func RandomPoissonV2Dtype(value tf.DataType) RandomPoissonV2Attr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Outputs random values from the Poisson distribution(s) described by rate.
-//
-// This op uses two algorithms, depending on rate. If rate >= 10, then
-// the algorithm by Hormann is used to acquire samples via
-// transformation-rejection.
-// See http://www.sciencedirect.com/science/article/pii/0167668793909974.
-//
-// Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform
-// random variables.
-// See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer
-// Programming, Volume 2. Addison Wesley
-//
-// Arguments:
-// shape: 1-D integer tensor. Shape of independent samples to draw from each
-// distribution described by the shape parameters given in rate.
-// rate: A tensor in which each scalar is a "rate" parameter describing the
-// associated poisson distribution.
-//
-// Returns A tensor with shape `shape + shape(rate)`. Each slice
-// `[:, ..., :, i0, i1, ...iN]` contains the samples drawn for
-// `rate[i0, i1, ...iN]`.
-func RandomPoissonV2(scope *Scope, shape tf.Output, rate tf.Output, optional ...RandomPoissonV2Attr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "RandomPoissonV2",
- Input: []tf.Input{
- shape, rate,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
type MatrixTriangularSolveAttr func(optionalAttr)
@@ -20264,27 +21338,31 @@ func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
// Computes the product along segments of a tensor.
//
-// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
// [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
// Instead of computing the sum over segments, it computes the product of all
// entries belonging to a segment such that:
//
-// \\(output_i = \prod_j data_j\\) where the product is over `j` such
-// that `segment_ids[j] == i`.
+// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+// `j...` such that `segment_ids[j...] == i`.
//
// If there is no entry for a given segment ID `i`, it outputs 1.
//
+// If the given segment ID `i` is negative, then the corresponding value is
+// dropped, and will not be included in the result.
+//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
-// first dimension.
+// segment_ids: A tensor whose shape is a prefix of `data.shape`.
//
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `num_segments`.
+// Returns Has same shape as data, except for the first `segment_ids.rank`
+// dimensions, which are replaced with a single dimension which has size
+// `num_segments`.
func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, num_segments tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
@@ -20299,90 +21377,172 @@ func UnsortedSegmentProd(scope *Scope, data tf.Output, segment_ids tf.Output, nu
return op.Output(0)
}
-// RandomUniformIntAttr is an optional argument to RandomUniformInt.
-type RandomUniformIntAttr func(optionalAttr)
-
-// RandomUniformIntSeed sets the optional seed attribute to value.
+// Computes the mean along sparse segments of a tensor.
//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func RandomUniformIntSeed(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// RandomUniformIntSeed2 sets the optional seed2 attribute to value.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func RandomUniformIntSeed2(value int64) RandomUniformIntAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
+// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+// dimension, selecting a subset of dimension 0, specified by `indices`.
+//
+// Arguments:
+//
+// indices: A 1-D tensor. Has same rank as `segment_ids`.
+// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseSegmentMean",
+ Input: []tf.Input{
+ data, indices, segment_ids,
+ },
}
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Outputs random integers from a uniform distribution.
-//
-// The generated values are uniform integers in the range `[minval, maxval)`.
-// The lower bound `minval` is included in the range, while the upper bound
-// `maxval` is excluded.
+// Deserializes a serialized tree ensemble config and replaces current tree
//
-// The random integers are slightly biased unless `maxval - minval` is an exact
-// power of two. The bias is small for values of `maxval - minval` significantly
-// smaller than the range of the output (either `2^32` or `2^64`).
+// ensemble.
//
// Arguments:
-// shape: The shape of the output tensor.
-// minval: 0-D. Inclusive lower bound on the generated integers.
-// maxval: 0-D. Exclusive upper bound on the generated integers.
+// tree_ensemble_handle: Handle to the tree ensemble.
+// stamp_token: Token to use as the new value of the resource stamp.
+// tree_ensemble_serialized: Serialized proto of the ensemble.
//
-// Returns A tensor of the specified shape filled with uniform random integers.
-func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf.Output, optional ...RandomUniformIntAttr) (output tf.Output) {
+// Returns the created operation.
+func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesDeserializeEnsemble",
+ Input: []tf.Input{
+ tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
+ },
}
+ return scope.AddOperation(opspec)
+}
+
+// Transforms a tf.Example proto (as a string) into typed tensors.
+//
+// Arguments:
+// serialized: A vector containing a batch of binary serialized Example protos.
+// dense_defaults: A list of Tensors (some may be empty), whose length matches
+// the length of `dense_keys`. dense_defaults[j] provides default values
+// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
+// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
+// The input type is inferred from dense_defaults[j], even when it's empty.
+// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
+// then the shape of dense_defaults[j] must match that of dense_shapes[j].
+// If dense_shapes[j] has an undefined major dimension (variable strides dense
+// feature), dense_defaults[j] must contain a single element:
+// the padding element.
+// num_sparse: The number of sparse features to be parsed from the example. This
+// must match the lengths of `sparse_keys` and `sparse_types`.
+// sparse_keys: A list of `num_sparse` strings.
+// The keys expected in the Examples' features associated with sparse values.
+// dense_keys: The keys expected in the Examples' features associated with dense
+// values.
+// sparse_types: A list of `num_sparse` types; the data types of data in each
+// Feature given in sparse_keys.
+// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
+// DT_INT64 (Int64List), and DT_STRING (BytesList).
+// dense_shapes: The shapes of data in each Feature given in dense_keys.
+// The length of this list must match the length of `dense_keys`. The
+// number of elements in the Feature corresponding to dense_key[j] must
+// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
+// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
+// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
+// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
+// D1, .., DN), where M is the number of blocks of elements of length
+// D1 * .... * DN, in the input.
+func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
opspec := tf.OpSpec{
- Type: "RandomUniformInt",
+ Type: "ParseSingleExample",
Input: []tf.Input{
- shape, minval, maxval,
+ serialized, tf.OutputList(dense_defaults),
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
+ scope.UpdateErr("ParseSingleExample", err)
+ return
+ }
+ return sparse_indices, sparse_values, sparse_shapes, dense_values
}
-// Computes the mean along sparse segments of a tensor.
-//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
+type WholeFileReaderV2Attr func(optionalAttr)
+
+// WholeFileReaderV2Container sets the optional container attribute to value.
//
-// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
-// dimension, selecting a subset of dimension 0, specified by `indices`.
+// value: If non-empty, this reader is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
//
-// Arguments:
+// value: If non-empty, this reader is named in the given bucket
+// with this shared_name. Otherwise, the node name is used instead.
+// If not specified, defaults to ""
+func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// A Reader that outputs the entire contents of a file as a value.
//
-// indices: A 1-D tensor. Has same rank as `segment_ids`.
-// segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+// To use, enqueue filenames in a Queue. The output of ReaderRead will
+// be a filename (key) and the contents of that file (value).
//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SparseSegmentMean(scope *Scope, data tf.Output, indices tf.Output, segment_ids tf.Output) (output tf.Output) {
+// Returns The handle to reference the Reader.
+func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
- Type: "SparseSegmentMean",
- Input: []tf.Input{
- data, indices, segment_ids,
- },
+ Type: "WholeFileReaderV2",
+
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -20431,8 +21591,9 @@ func Cosh(scope *Scope, x tf.Output) (y tf.Output) {
// Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
// misisng, the `output` tensor at that position will be zeroed.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20577,8 +21738,9 @@ func SparseSegmentMeanGrad(scope *Scope, grad tf.Output, indices tf.Output, segm
//
// N is the size of the segment being reduced.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20636,8 +21798,9 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
// Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
// misisng, the `output` tensor at that position will be zeroed.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Arguments:
//
@@ -20998,8 +22161,9 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
// Computes the maximum along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
@@ -21013,7 +22177,7 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -22364,6 +23528,58 @@ func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, o
return op.Output(0)
}
+// MultiDeviceIteratorFromStringHandleAttr is an optional argument to MultiDeviceIteratorFromStringHandle.
+type MultiDeviceIteratorFromStringHandleAttr func(optionalAttr)
+
+// MultiDeviceIteratorFromStringHandleOutputTypes sets the optional output_types attribute to value.
+//
+// value: The type list for the return values.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func MultiDeviceIteratorFromStringHandleOutputTypes(value []tf.DataType) MultiDeviceIteratorFromStringHandleAttr {
+ return func(m optionalAttr) {
+ m["output_types"] = value
+ }
+}
+
+// MultiDeviceIteratorFromStringHandleOutputShapes sets the optional output_shapes attribute to value.
+//
+// value: The list of shapes being produced.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func MultiDeviceIteratorFromStringHandleOutputShapes(value []tf.Shape) MultiDeviceIteratorFromStringHandleAttr {
+ return func(m optionalAttr) {
+ m["output_shapes"] = value
+ }
+}
+
+// Generates a MultiDeviceIterator resource from its provided string handle.
+//
+// Arguments:
+// string_handle: String representing the resource.
+//
+// Returns A MultiDeviceIterator resource.
+func MultiDeviceIteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ...MultiDeviceIteratorFromStringHandleAttr) (multi_device_iterator tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIteratorFromStringHandle",
+ Input: []tf.Input{
+ string_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
type MutableHashTableV2Attr func(optionalAttr)
@@ -23429,29 +24645,57 @@ func TensorListSetItem(scope *Scope, input_handle tf.Output, index tf.Output, it
return op.Output(0)
}
-// Computes the matrix exponential of one or more square matrices:
+// Creates a Tensor by indexing into the TensorList.
//
-// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
+// Each row in the produced Tensor corresponds to the element in the TensorList
+// specified by the given index (see `tf.gather`).
//
-// \\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
-//
-// The exponential is computed using a combination of the scaling and squaring
-// method and the Pade approximation. Details can be founds in:
-// Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-// revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
-//
-// The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-// form square matrices. The output is a tensor of the same shape as the input
-// containing the exponential for all input submatrices `[..., :, :]`.
+// input_handle: The input tensor list.
+// indices: The indices used to index into the list.
+// values: The tensor.
+func TensorListGather(scope *Scope, input_handle tf.Output, indices tf.Output, element_dtype tf.DataType) (values tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"element_dtype": element_dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorListGather",
+ Input: []tf.Input{
+ input_handle, indices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a TensorList by indexing into a Tensor.
//
-// Arguments:
-// input: Shape is `[..., M, M]`.
+// Each member of the TensorList corresponds to one row of the input tensor,
+// specified by the given index (see `tf.gather`).
//
-// Returns Shape is `[..., M, M]`.
+// tensor: The input tensor.
+// indices: The indices used to index into the list.
+// element_shape: The shape of the elements in the list (can be less specified than
+// the shape of the tensor).
+// output_handle: The TensorList.
+func TensorListScatter(scope *Scope, tensor tf.Output, indices tf.Output, element_shape tf.Output) (output_handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorListScatter",
+ Input: []tf.Input{
+ tensor, indices, element_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Deprecated, use python implementation tf.linalg.matrix_exponential.
//
-// @compatibility(scipy)
-// Equivalent to scipy.linalg.expm
-// @end_compatibility
+// DEPRECATED at GraphDef version 27: Use Python implementation tf.linalg.matrix_exponential instead.
func MatrixExponential(scope *Scope, input tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
@@ -23904,6 +25148,45 @@ func Svd(scope *Scope, input tf.Output, optional ...SvdAttr) (s tf.Output, u tf.
return op.Output(0), op.Output(1), op.Output(2)
}
+// PrintV2Attr is an optional argument to PrintV2.
+type PrintV2Attr func(optionalAttr)
+
+// PrintV2OutputStream sets the optional output_stream attribute to value.
+//
+// value: A string specifying the output stream or logging level to print to.
+// If not specified, defaults to "stderr"
+func PrintV2OutputStream(value string) PrintV2Attr {
+ return func(m optionalAttr) {
+ m["output_stream"] = value
+ }
+}
+
+// Prints a string scalar.
+//
+// Prints a string scalar to the desired output_stream.
+//
+// Arguments:
+// input: The string scalar to print.
+//
+// Returns the created operation.
+func PrintV2(scope *Scope, input tf.Output, optional ...PrintV2Attr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "PrintV2",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
// QueueEnqueueManyV2Attr is an optional argument to QueueEnqueueManyV2.
type QueueEnqueueManyV2Attr func(optionalAttr)
@@ -23957,8 +25240,9 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output,
// Computes the product along segments of a tensor.
//
-// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-// segments.
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
//
// Computes a tensor such that
// \\(output_i = \prod_j data_j\\) where the product is over `j` such
@@ -23972,7 +25256,7 @@ func QueueEnqueueManyV2(scope *Scope, handle tf.Output, components []tf.Output,
//
// Arguments:
//
-// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
@@ -24997,7 +26281,7 @@ func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr {
// Update '*var' according to the Adam algorithm.
//
-// $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+// $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
// $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
// $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
// $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
@@ -26636,36 +27920,6 @@ func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset t
return op.Output(0)
}
-// Debugging/model interpretability outputs for each example.
-//
-// It traverses all the trees and computes debug metrics for individual examples,
-// such as getting split feature ids and logits after each split along the decision
-// path used to compute directional feature contributions.
-//
-// Arguments:
-//
-// bucketized_features: A list of rank 1 Tensors containing bucket id for each
-// feature.
-// logits_dimension: scalar, dimension of the logits, to be used for constructing the protos in
-// examples_debug_outputs_serialized.
-//
-// Returns Output rank 1 Tensor containing a proto serialized as a string for each example.
-func BoostedTreesExampleDebugOutputs(scope *Scope, tree_ensemble_handle tf.Output, bucketized_features []tf.Output, logits_dimension int64) (examples_debug_outputs_serialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"logits_dimension": logits_dimension}
- opspec := tf.OpSpec{
- Type: "BoostedTreesExampleDebugOutputs",
- Input: []tf.Input{
- tree_ensemble_handle, tf.OutputList(bucketized_features),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds a value to the current value of a variable.
//
// Any ReadVariableOp with a control dependency on this op is guaranteed to
@@ -27044,8 +28298,10 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source
// If `len` defines a substring that would extend beyond the length of the input
// string, then as many characters as possible are used.
//
-// If `pos` is negative or specifies a character index larger than any of the input
-// strings, then an `InvalidArgumentError` is thrown.
+// A negative `pos` indicates distance within the string backwards from the end.
+//
+// If `pos` specifies an index which is out of range for any of the input strings,
+// then an `InvalidArgumentError` is thrown.
//
// `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
// Op creation.
@@ -27450,35 +28706,6 @@ func MakeIterator(scope *Scope, dataset tf.Output, iterator tf.Output) (o *tf.Op
return scope.AddOperation(opspec)
}
-// Makes the summary of accumulated stats for the batch.
-//
-// The summary stats contains gradients and hessians accumulated into the corresponding node and bucket for each example.
-//
-// Arguments:
-// node_ids: int32 Rank 1 Tensor containing node ids, which each example falls into for the requested layer.
-// gradients: float32; Rank 2 Tensor (shape=[#examples, 1]) for gradients.
-// hessians: float32; Rank 2 Tensor (shape=[#examples, 1]) for hessians.
-// bucketized_features_list: int32 list of Rank 1 Tensors, each containing the bucketized feature (for each feature column).
-// max_splits: int; the maximum number of splits possible in the whole tree.
-// num_buckets: int; equals to the maximum possible value of bucketized feature.
-//
-// Returns output Rank 4 Tensor (shape=[#features, #splits, #buckets, 2]) containing accumulated stats put into the corresponding node and bucket. The first index of 4th dimension refers to gradients, and the second to hessians.
-func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf.Output, hessians tf.Output, bucketized_features_list []tf.Output, max_splits int64, num_buckets int64) (stats_summary tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"max_splits": max_splits, "num_buckets": num_buckets}
- opspec := tf.OpSpec{
- Type: "BoostedTreesMakeStatsSummary",
- Input: []tf.Input{
- node_ids, gradients, hessians, tf.OutputList(bucketized_features_list),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adjust the contrast of one or more images.
//
// `images` is a tensor of at least 3 dimensions. The last 3 dimensions are
@@ -27671,6 +28898,8 @@ func IteratorFromStringHandle(scope *Scope, string_handle tf.Output, optional ..
// On GPU, if an out of bound index is found, a 0 is stored in the
// corresponding output value.
//
+// See also `tf.batch_gather` and `tf.gather_nd`.
+//
// Arguments:
// params: The tensor from which to gather values. Must be at least rank
// `axis + 1`.
@@ -28153,6 +29382,58 @@ func IteratorGetNextAsOptional(scope *Scope, iterator tf.Output, output_types []
return op.Output(0)
}
+// Fast Fourier transform.
+//
+// Computes the 1-dimensional discrete Fourier transform over the inner-most
+// dimension of `input`.
+//
+// Arguments:
+// input: A complex64 tensor.
+//
+// Returns A complex64 tensor of the same shape as `input`. The inner-most
+// dimension of `input` is replaced with its 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.fft
+// @end_compatibility
+func FFT(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "FFT",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Identity transformation that models performance.
+//
+// Identity transformation that models performance.
+//
+// Arguments:
+// input_dataset: A variant tensor representing the input dataset.
+//
+//
+func ModelDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ModelDataset",
+ Input: []tf.Input{
+ input_dataset,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Performs a padding as a preprocess during a convolution.
//
// Similar to FusedResizeAndPadConv2d, this op allows for an optimized
@@ -28842,10 +30123,16 @@ func EncodeBase64(scope *Scope, input tf.Output, optional ...EncodeBase64Attr) (
//
// Arguments:
//
-// window_size: A scalar representing the number of elements to accumulate in a window.
+// size: A scalar representing the number of elements to accumulate in a window.
+// shift: A scalar representing the steps moving the sliding window forward in one
+// iteration. It must be positive.
+// stride: A scalar representing the stride of the input elements of the sliding window.
+// It must be positive.
+// drop_remainder: A scalar representing whether a window should be dropped in case its size is
+// smaller than desired.
//
//
-func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+func WindowDataset(scope *Scope, input_dataset tf.Output, size tf.Output, shift tf.Output, stride tf.Output, drop_remainder tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
return
}
@@ -28853,7 +30140,7 @@ func WindowDataset(scope *Scope, input_dataset tf.Output, window_size tf.Output,
opspec := tf.OpSpec{
Type: "WindowDataset",
Input: []tf.Input{
- input_dataset, window_size,
+ input_dataset, size, shift, stride, drop_remainder,
},
Attrs: attrs,
}
@@ -30008,27 +31295,6 @@ func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, val
return op.Output(0)
}
-// Creates a tree ensemble model and returns a handle to it.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble resource to be created.
-// stamp_token: Token to use as the initial value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the tree ensemble.
-//
-// Returns the created operation.
-func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesCreateEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Applies sparse addition to `input` using individual values or slices
//
// from `updates` according to indices `indices`. The updates are non-aliasing:
@@ -30063,7 +31329,7 @@ func BoostedTreesCreateEnsemble(scope *Scope, tree_ensemble_handle tf.Output, st
//
// [1, 13, 3, 14, 14, 6, 7, 20]
//
-// See @{tf.scatter_nd} for more details about how to make updates to slices.
+// See `tf.scatter_nd` for more details about how to make updates to slices.
//
// Arguments:
// input: A Tensor.
@@ -30216,6 +31482,32 @@ func FractionalMaxPool(scope *Scope, value tf.Output, pooling_ratio []float32, o
return op.Output(0), op.Output(1), op.Output(2)
}
+// Creates a MultiDeviceIterator resource.
+//
+// Arguments:
+// devices: A list of devices the iterator works across.
+// shared_name: If non-empty, this resource will be shared under the given name
+// across multiple sessions.
+// container: If non-empty, this resource is placed in the given container.
+// Otherwise, a default container is used.
+// output_types: The type list for the return values.
+// output_shapes: The list of shapes being produced.
+//
+// Returns Handle to the resource created.
+func MultiDeviceIterator(scope *Scope, devices []string, shared_name string, container string, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"devices": devices, "shared_name": shared_name, "container": container, "output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "MultiDeviceIterator",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Deprecated. Use TensorArraySizeV3
//
// DEPRECATED at GraphDef version 26: Use TensorArraySizeV3
@@ -30680,6 +31972,41 @@ func MapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...MapIncomp
return op.Output(0)
}
+// Generate the bucket boundaries for each feature based on accumulated summaries.
+//
+// An op that returns a list of float tensors for a quantile stream resource. Each
+// tensor is Rank 1 containing bucket boundaries for a single feature.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// num_features: inferred int; number of features to get bucket boundaries for.
+//
+// Returns float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+func BoostedTreesQuantileStreamResourceGetBucketBoundaries(scope *Scope, quantile_stream_resource_handle tf.Output, num_features int64) (bucket_boundaries []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_features": num_features}
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceGetBucketBoundaries",
+ Input: []tf.Input{
+ quantile_stream_resource_handle,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if bucket_boundaries, idx, err = makeOutputList(op, idx, "bucket_boundaries"); err != nil {
+ scope.UpdateErr("BoostedTreesQuantileStreamResourceGetBucketBoundaries", err)
+ return
+ }
+ return bucket_boundaries
+}
+
// OrderedMapUnstageAttr is an optional argument to OrderedMapUnstage.
type OrderedMapUnstageAttr func(optionalAttr)
@@ -30751,6 +32078,43 @@ func OrderedMapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []
return values
}
+// BoostedTreesQuantileStreamResourceHandleOpAttr is an optional argument to BoostedTreesQuantileStreamResourceHandleOp.
+type BoostedTreesQuantileStreamResourceHandleOpAttr func(optionalAttr)
+
+// BoostedTreesQuantileStreamResourceHandleOpContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpContainer(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// BoostedTreesQuantileStreamResourceHandleOpSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func BoostedTreesQuantileStreamResourceHandleOpSharedName(value string) BoostedTreesQuantileStreamResourceHandleOpAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Creates a handle to a BoostedTreesQuantileStreamResource.
+func BoostedTreesQuantileStreamResourceHandleOp(scope *Scope, optional ...BoostedTreesQuantileStreamResourceHandleOpAttr) (resource tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceHandleOp",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// OrderedMapSizeAttr is an optional argument to OrderedMapSize.
type OrderedMapSizeAttr func(optionalAttr)
@@ -31777,144 +33141,3 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1), op.Output(2)
}
-
-// Transforms a tf.Example proto (as a string) into typed tensors.
-//
-// Arguments:
-// serialized: A vector containing a batch of binary serialized Example protos.
-// dense_defaults: A list of Tensors (some may be empty), whose length matches
-// the length of `dense_keys`. dense_defaults[j] provides default values
-// when the example's feature_map lacks dense_key[j]. If an empty Tensor is
-// provided for dense_defaults[j], then the Feature dense_keys[j] is required.
-// The input type is inferred from dense_defaults[j], even when it's empty.
-// If dense_defaults[j] is not empty, and dense_shapes[j] is fully defined,
-// then the shape of dense_defaults[j] must match that of dense_shapes[j].
-// If dense_shapes[j] has an undefined major dimension (variable strides dense
-// feature), dense_defaults[j] must contain a single element:
-// the padding element.
-// num_sparse: The number of sparse features to be parsed from the example. This
-// must match the lengths of `sparse_keys` and `sparse_types`.
-// sparse_keys: A list of `num_sparse` strings.
-// The keys expected in the Examples' features associated with sparse values.
-// dense_keys: The keys expected in the Examples' features associated with dense
-// values.
-// sparse_types: A list of `num_sparse` types; the data types of data in each
-// Feature given in sparse_keys.
-// Currently the ParseSingleExample op supports DT_FLOAT (FloatList),
-// DT_INT64 (Int64List), and DT_STRING (BytesList).
-// dense_shapes: The shapes of data in each Feature given in dense_keys.
-// The length of this list must match the length of `dense_keys`. The
-// number of elements in the Feature corresponding to dense_key[j] must
-// always equal dense_shapes[j].NumEntries(). If dense_shapes[j] ==
-// (D0, D1, ..., DN) then the shape of output Tensor dense_values[j]
-// will be (D0, D1, ..., DN): In the case dense_shapes[j] = (-1, D1,
-// ..., DN), the shape of the output Tensor dense_values[j] will be (M,
-// D1, .., DN), where M is the number of blocks of elements of length
-// D1 * .... * DN, in the input.
-func ParseSingleExample(scope *Scope, serialized tf.Output, dense_defaults []tf.Output, num_sparse int64, sparse_keys []string, dense_keys []string, sparse_types []tf.DataType, dense_shapes []tf.Shape) (sparse_indices []tf.Output, sparse_values []tf.Output, sparse_shapes []tf.Output, dense_values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_sparse": num_sparse, "sparse_keys": sparse_keys, "dense_keys": dense_keys, "sparse_types": sparse_types, "dense_shapes": dense_shapes}
- opspec := tf.OpSpec{
- Type: "ParseSingleExample",
- Input: []tf.Input{
- serialized, tf.OutputList(dense_defaults),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if sparse_indices, idx, err = makeOutputList(op, idx, "sparse_indices"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_values, idx, err = makeOutputList(op, idx, "sparse_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if sparse_shapes, idx, err = makeOutputList(op, idx, "sparse_shapes"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- if dense_values, idx, err = makeOutputList(op, idx, "dense_values"); err != nil {
- scope.UpdateErr("ParseSingleExample", err)
- return
- }
- return sparse_indices, sparse_values, sparse_shapes, dense_values
-}
-
-// WholeFileReaderV2Attr is an optional argument to WholeFileReaderV2.
-type WholeFileReaderV2Attr func(optionalAttr)
-
-// WholeFileReaderV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this reader is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func WholeFileReaderV2Container(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// WholeFileReaderV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this reader is named in the given bucket
-// with this shared_name. Otherwise, the node name is used instead.
-// If not specified, defaults to ""
-func WholeFileReaderV2SharedName(value string) WholeFileReaderV2Attr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// A Reader that outputs the entire contents of a file as a value.
-//
-// To use, enqueue filenames in a Queue. The output of ReaderRead will
-// be a filename (key) and the contents of that file (value).
-//
-// Returns The handle to reference the Reader.
-func WholeFileReaderV2(scope *Scope, optional ...WholeFileReaderV2Attr) (reader_handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "WholeFileReaderV2",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deserializes a serialized tree ensemble config and replaces current tree
-//
-// ensemble.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-// stamp_token: Token to use as the new value of the resource stamp.
-// tree_ensemble_serialized: Serialized proto of the ensemble.
-//
-// Returns the created operation.
-func BoostedTreesDeserializeEnsemble(scope *Scope, tree_ensemble_handle tf.Output, stamp_token tf.Output, tree_ensemble_serialized tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesDeserializeEnsemble",
- Input: []tf.Input{
- tree_ensemble_handle, stamp_token, tree_ensemble_serialized,
- },
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index c7382ff231..7ef862ae79 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -10,7 +10,7 @@
## Quickstart
-- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/install_java)
+- Refer to [Installing TensorFlow for Java](https://www.tensorflow.org/install/lang_java)
- [Javadoc](https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary)
- [![Maven Central](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow/badge.svg)](https://maven-badges.herokuapp.com/maven-central/org.tensorflow/tensorflow)
@@ -22,8 +22,7 @@ native libraries will need to be built from source.
1. Install [bazel](https://www.bazel.build/versions/master/docs/install.html)
2. Setup the environment to build TensorFlow from source code
- ([Linux](https://www.tensorflow.org/install/install_sources#PrepareLinux)
- or [macOS](https://www.tensorflow.org/install/install_sources#PrepareMac)).
+ ([Linux or macOS](https://www.tensorflow.org/install/source)).
If you'd like to skip reading those details and do not care about GPU
support, try the following:
@@ -35,7 +34,7 @@ native libraries will need to be built from source.
brew install swig
```
-3. [Configure](https://www.tensorflow.org/install/install_sources#configure_the_installation)
+3. [Configure](https://www.tensorflow.org/install/source)
(e.g., enable GPU support) and build:
```sh
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index f9093ce385..9fc6969c20 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 1208956dec..68712082e1 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 755449cb3c..f031173c99 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index e1bf2c7dba..2cac27990e 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index b89f042567..8a93091276 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
index 1b7995be2c..014bd8d212 100644
--- a/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
+++ b/tensorflow/java/maven/spark-tensorflow-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-tensorflow-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-hadoop/pom.xml b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
index 0fe6f4dce4..d07c5fcd98 100644
--- a/tensorflow/java/maven/tensorflow-hadoop/pom.xml
+++ b/tensorflow/java/maven/tensorflow-hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 0de90244b1..af0c68a4ed 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.10.0</version>
+ <version>1.11.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 19729813a1..79f14466e6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1998,6 +1998,29 @@ py_library(
)
py_library(
+ name = "while_v2",
+ srcs = [
+ "ops/while_v2.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":cond_v2_impl",
+ ":constant_op",
+ ":control_flow_util",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":functional_ops_gen",
+ ":gradients_impl",
+ ":list_ops",
+ ":tensor_shape",
+ ":util",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+py_library(
name = "cond_v2_impl",
srcs = [
"ops/cond_v2_impl.py",
@@ -2301,6 +2324,8 @@ py_library(
deps = [
":framework_for_generated_wrappers",
":logging_ops_gen",
+ ":platform",
+ ":string_ops",
":util",
],
)
@@ -3058,6 +3083,7 @@ cuda_py_test(
":functional_ops",
":gradients",
":layers",
+ ":list_ops",
":math_grad",
":math_ops",
":nn_grad",
@@ -3089,7 +3115,7 @@ cuda_py_test(
cuda_py_test(
name = "image_grad_test",
- size = "small",
+ size = "medium",
srcs = ["ops/image_grad_test.py"],
additional_deps = [
":client_testlib",
@@ -3737,6 +3763,19 @@ cuda_py_tests(
],
)
+cc_library(
+ name = "session_ref",
+ srcs = ["client/session_ref.cc"],
+ hdrs = ["client/session_ref.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:replay_log_proto_cc",
+ ],
+)
+
tf_cuda_library(
name = "tf_session_helper",
srcs = ["client/tf_session_helper.cc"],
@@ -3747,6 +3786,7 @@ tf_cuda_library(
":ndarray_tensor_bridge",
":numpy_lib",
":safe_ptr",
+ ":session_ref",
":test_ops_kernels",
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -3759,7 +3799,6 @@ tf_cuda_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core:session_ref",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD
new file mode 100644
index 0000000000..3289b447e7
--- /dev/null
+++ b/tensorflow/python/autograph/BUILD
@@ -0,0 +1,31 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "autograph",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/autograph/impl",
+ "//tensorflow/python/autograph/lang",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
index 06fb7b03d5..1ded5ba5f6 100644
--- a/tensorflow/contrib/autograph/CONTRIBUTING.md
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -2,6 +2,15 @@
We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
+### Note to active contributors
+
+In preparation for TF 2.0, we moved the code base of AutoGraph from
+`tensorflow/contrib/autograph` to `tensorflow/python/autograph`. The move
+does not impact functionality, and AutoGraph will remain accessible under
+`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
+
+When
+
## TensorFlow Code of Conduct
Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
diff --git a/tensorflow/contrib/autograph/LIMITATIONS.md b/tensorflow/python/autograph/LIMITATIONS.md
index d8b1cb7616..d8b1cb7616 100644
--- a/tensorflow/contrib/autograph/LIMITATIONS.md
+++ b/tensorflow/python/autograph/LIMITATIONS.md
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
new file mode 100644
index 0000000000..bfe21b4765
--- /dev/null
+++ b/tensorflow/python/autograph/README.md
@@ -0,0 +1,143 @@
+# AutoGraph
+
+IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+
+AutoGraph is a Python to TensorFlow compiler.
+
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
+
+For example, this Python function:
+
+```
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+would be converted to this:
+
+```
+def graph_mode_f(x):
+ with tf.name_scope('f'):
+
+ def if_true():
+ with tf.name_scope('if_true'):
+ x_1, = x,
+ x_1 = tf.negative(x_1)
+ return x_1,
+
+ def if_false():
+ with tf.name_scope('if_false'):
+ x_1, = x,
+ return x_1,
+ x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false)
+ return x
+```
+
+so you can use it like an op:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1.0)
+
+ converted_f = autograph.to_graph(f)
+ y = converted_f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+# Getting started
+
+Use AutoGraph in one of the following ways, described below:
+
+ 1. Annotations (simpler)
+ 2. Functional API (more flexible)
+
+To get started, install the latest nightly TensorFlow build:
+
+```shell
+pip install -U tf-nightly
+```
+
+Then import the `autograph` module from `tf.contrib`:
+
+```
+from tensorflow.python import autograph as ag
+```
+
+### Related links
+
+Articles:
+
+ * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
+
+Interactive notebooks:
+
+ * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
+ * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
+ * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
+ * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
+ * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
+ * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
+ * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
+
+## Using with annotations
+
+Annotating a function or class with `@convert` converts it in place:
+
+```
+@ag.convert()
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+... so that it always outputs TensorFlow code:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1)
+
+ y = f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+## Using the functional API
+
+The functional API allows you to convert an existing function, class or object after it was defined:
+
+```
+converted_f = ag.to_graph(f)
+
+print(converted_f(tf.constant(-1)))
+# Output: Tensor
+
+print(f(-1))
+# Output: 1
+```
+
+You can use the functional API to inspect the generated code as well:
+
+```
+print(ag.to_code(f))
+# Output: <Python and TensorFlow code>
+```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/python/autograph/STYLE_GUIDE.md
index 7e6b0cc27d..7e6b0cc27d 100644
--- a/tensorflow/contrib/autograph/STYLE_GUIDE.md
+++ b/tensorflow/python/autograph/STYLE_GUIDE.md
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
new file mode 100644
index 0000000000..5ed5e85158
--- /dev/null
+++ b/tensorflow/python/autograph/__init__.py
@@ -0,0 +1,70 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Autograph compiles Python code into equivalent TensorFlow code.
+
+Equivalent here means that they have the same effect when executed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Bring only the relevant symbols to the top level.
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core.errors import GraphConstructionError
+from tensorflow.python.autograph.core.errors import TfRuntimeError
+from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import ConversionOptions
+from tensorflow.python.autograph.impl.api import RunMode
+from tensorflow.python.autograph.impl.api import convert
+from tensorflow.python.autograph.impl.api import converted_call
+from tensorflow.python.autograph.impl.api import do_not_convert
+from tensorflow.python.autograph.impl.api import to_code
+from tensorflow.python.autograph.impl.api import to_graph
+from tensorflow.python.autograph.lang.directives import set_element_type
+from tensorflow.python.autograph.lang.directives import set_loop_options
+from tensorflow.python.autograph.lang.special_functions import stack
+from tensorflow.python.autograph.lang.special_functions import tensor_list
+from tensorflow.python.autograph.pyct.transformer import AutographParseError
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ # Main API
+ 'ConversionOptions',
+ 'RunMode',
+ 'convert',
+ 'converted_call',
+ 'do_not_convert',
+ 'to_code',
+ 'to_graph',
+ # Overloaded operators
+ 'operators',
+ # Errors
+ 'improved_errors',
+ 'GraphConstructionError',
+ 'TfRuntimeError',
+ # Python language "extensions"
+ 'set_element_type',
+ 'set_loop_options',
+ 'stack',
+ 'tensor_list',
+ # Exceptions
+ 'AutographParseError',
+ # Utilities: to be removed
+ 'utils',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 2d2ab7040a..7b029de8ed 100644
--- a/tensorflow/contrib/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -38,11 +38,11 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/core",
- "//tensorflow/contrib/autograph/lang",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/core",
+ "//tensorflow/python/autograph/lang",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
"@gast_archive//:gast",
],
)
@@ -54,8 +54,8 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -65,8 +65,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -77,8 +77,8 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -90,9 +90,9 @@ py_test(
tags = ["no_windows"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/impl",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/impl",
],
)
@@ -102,8 +102,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -113,8 +113,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -124,8 +124,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -139,8 +139,8 @@ py_test(
],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -150,9 +150,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/lang",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/lang",
],
)
@@ -161,9 +161,9 @@ py_test(
srcs = ["name_scopes_test.py"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -173,8 +173,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -184,8 +184,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -195,8 +195,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -207,8 +207,8 @@ py_test(
tags = ["notsan"],
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
],
)
@@ -218,9 +218,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -230,9 +230,9 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -242,8 +242,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":converters",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/core:test_lib",
+ "//tensorflow/python/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/converters/__init__.py b/tensorflow/python/autograph/converters/__init__.py
index 6325ac78dc..6325ac78dc 100644
--- a/tensorflow/contrib/autograph/converters/__init__.py
+++ b/tensorflow/python/autograph/converters/__init__.py
diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py
index af2f20f267..56a97534c4 100644
--- a/tensorflow/contrib/autograph/converters/asserts.py
+++ b/tensorflow/python/autograph/converters/asserts.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
class AssertTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index 38faba45df..01282f9e62 100644
--- a/tensorflow/contrib/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
index 180779670d..bd6b0b248c 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class _Break(object):
diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
index fcae7d68c0..39406a969d 100644
--- a/tensorflow/contrib/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
index 29dce13999..583c978395 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -20,10 +20,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.operators import py_builtins
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
class BuiltinFunctionTransformer(converter.Base):
@@ -48,8 +48,13 @@ class BuiltinFunctionTransformer(converter.Base):
node = self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
live_val = anno.getanno(node.func, 'live_val')
- if live_val in py_builtins.SUPPORTED_BUILTINS:
- node = self._convert_builtin(live_val, node.args, as_expression=True)
+ try:
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
+ except TypeError:
+ # Not everything in Python is hashable. If it isn't then it's definitely
+ # not a supported built-in.
+ return node
return node
def visit_Print(self, node):
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
index 3e3a04f38b..2ed14c14e7 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -36,7 +36,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return len(a)
with self.converted(test_fn, builtin_functions, {'len': len}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
ops = result.test_fn(p)
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
@@ -50,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a\n'):
sess.run(result.test_fn('a'))
@@ -63,12 +63,22 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a, b, c)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a 1 [2, 3]\n'):
sess.run(
result.test_fn(
constant_op.constant('a'), constant_op.constant(1), [2, 3]))
+ def test_conversion_robust_to_unhashable_callables(self):
+
+ def test_fn():
+ return foo() # pylint:disable=undefined-variable
+
+ with self.converted(test_fn, builtin_functions, {'foo': {
+ 'a': 'b'
+ }.keys}) as result:
+ self.assertListEqual(list(result.test_fn()), ['a'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 2d1bed3367..fc2075b781 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -26,12 +26,12 @@ from collections import namedtuple
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
from tensorflow.python.util import tf_inspect
@@ -238,9 +238,16 @@ class CallTreeTransformer(converter.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- ag__.converted_call(func, True, False, False, {}, args)
+ ag__.converted_call(
+ func,
+ ag__.ConversionOptions.new(recursive=recursive_val),
+ args)
"""
- call_expr = templates.replace(template, func=node.func, args=node.args)
+ call_expr = templates.replace(
+ template,
+ func=node.func,
+ recursive_val=parser.parse_expression(str(self.ctx.program.recursive)),
+ args=node.args)
new_call = call_expr[0].value
# TODO(mdan): Improve the template mechanism to better support this.
new_call.keywords = node.keywords
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index ca4d1f2932..0e50f42c6a 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
index 63f649dfdf..40728f555d 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class _FunctionDefs(object):
diff --git a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py b/tensorflow/python/autograph/converters/conditional_expressions_test.py
index 95a3108741..dd1f8d485c 100644
--- a/tensorflow/contrib/autograph/converters/conditional_expressions_test.py
+++ b/tensorflow/python/autograph/converters/conditional_expressions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/python/autograph/converters/continue_statements.py
index 0476e97c15..584cdc1efd 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements.py
+++ b/tensorflow/python/autograph/converters/continue_statements.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py
index 37c15211b4..d6aaa50443 100644
--- a/tensorflow/contrib/autograph/converters/continue_statements_test.py
+++ b/tensorflow/python/autograph/converters/continue_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.eager import context as tfe_ctx
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 3530fbb2ec..416a60d2ee 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -20,12 +20,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis import annos
class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 1d04ba3ba6..cfa0ea920c 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/decorators.py b/tensorflow/python/autograph/converters/decorators.py
index 3471bd11d6..724f0fe5ed 100644
--- a/tensorflow/contrib/autograph/converters/decorators.py
+++ b/tensorflow/python/autograph/converters/decorators.py
@@ -24,8 +24,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/converters/decorators_test.py b/tensorflow/python/autograph/converters/decorators_test.py
index 095abc5edc..fb31c8d583 100644
--- a/tensorflow/contrib/autograph/converters/decorators_test.py
+++ b/tensorflow/python/autograph/converters/decorators_test.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python import autograph
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
@@ -136,6 +138,12 @@ class DecoratorsTest(converter_testing.TestCase):
return inner_fn(a)
+ # Work around TensorFlow's symbol suppression mechanism that causes core to
+ # be invisible in the generated code.
+ core_mod = imp.new_module('core')
+ core_mod.converter_testing = converter_testing
+ autograph.core = core_mod
+
# 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn)
self.assertEqual(14, test_fn(1))
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py
index 77f625bac7..fc646348ef 100644
--- a/tensorflow/contrib/autograph/converters/directives.py
+++ b/tensorflow/python/autograph/converters/directives.py
@@ -25,9 +25,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
ENCLOSING_LOOP = 'enclosing_loop'
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py
index a2d083b891..570fb8e379 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/python/autograph/converters/directives_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import directives as directives_converter
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core.converter import AgAnno
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import directives as directives_converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core.converter import AgAnno
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/error_handlers.py b/tensorflow/python/autograph/converters/error_handlers.py
index 1936821394..de46c0c830 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers.py
+++ b/tensorflow/python/autograph/converters/error_handlers.py
@@ -22,9 +22,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
class ErrorRewritingTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/error_handlers_test.py b/tensorflow/python/autograph/converters/error_handlers_test.py
index 5d61b220af..676ff9e02b 100644
--- a/tensorflow/contrib/autograph/converters/error_handlers_test.py
+++ b/tensorflow/python/autograph/converters/error_handlers_test.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py
index ecf4628816..5be6cb9a98 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions.py
@@ -32,8 +32,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
# TODO(mdan): This should covert directly to operator calls.
diff --git a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py
index 59b5ce9ca0..1e66139af6 100644
--- a/tensorflow/contrib/autograph/converters/list_comprehensions_test.py
+++ b/tensorflow/python/autograph/converters/list_comprehensions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import list_comprehensions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import list_comprehensions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/python/autograph/converters/lists.py
index a02fc827b8..8180801753 100644
--- a/tensorflow/contrib/autograph/converters/lists.py
+++ b/tensorflow/python/autograph/converters/lists.py
@@ -32,12 +32,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# Tags for local state.
diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
index c5e2dcf75e..f6da845fcc 100644
--- a/tensorflow/contrib/autograph/converters/lists_test.py
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.lang import special_functions
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.lang import special_functions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index 41c3424fa3..8c4d53f9a8 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -23,10 +23,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
# TODO(mdan): Properly extrack boolean ops according to lazy eval rules.
@@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'ag__.utils.dynamic_is',
- gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base):
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+ def _has_matching_func(self, operator):
+ op_type = type(operator)
+ return op_type in self.op_mapping
+
def _matching_func(self, operator):
op_type = type(operator)
- mapped_op = self.op_mapping.get(op_type)
- if not mapped_op:
- raise NotImplementedError('operator %s is not yet supported' % op_type)
- return mapped_op
+ return self.op_mapping[op_type]
def _as_function(self, func_name, args):
template = """
@@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
+
+ if not all(self._has_matching_func(op) for op in node.ops):
+ if len(node.ops) == 1:
+ # Basic expressions are safe to leave as they are.
+ return node
+ else:
+ raise NotImplementedError(
+ 'compound expression with at least one unsupported '
+ 'operator: {}'.format(node.ops))
+
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
index 409a73afba..b78b4d3a6a 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
- def test_ag_utils_lookup(self):
+ def test_unsupported_ops(self):
def test_fn(a, b):
- return a is b or a is not b
+ return a in b
- with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
- ) as result:
- with self.cached_session() as sess:
- self.assertTrue(sess.run(result.test_fn(True, False)))
+ with self.converted(test_fn, logical_expressions, {}) as result:
+ self.assertTrue(result.test_fn('a', ('a',)))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/name_scopes.py
index dd6c6bf960..a9c55ccff0 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/name_scopes.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
class FunctionNameScopeTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/name_scopes_test.py
index a329b0db70..73933c1c4f 100644
--- a/tensorflow/contrib/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/name_scopes_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index a351cd81b8..62da045d6a 100644
--- a/tensorflow/contrib/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Move this logic into transformer_base.
diff --git a/tensorflow/contrib/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 3c7c8c8a25..01dd03da0b 100644
--- a/tensorflow/contrib/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
index b808604f0a..6e48e57bde 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -36,12 +36,12 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class SymbolNamer(object):
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
index 5fe5114d4b..cef3199169 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/python/autograph/converters/slices.py
index c527f98613..11cea6de5b 100644
--- a/tensorflow/contrib/autograph/converters/slices.py
+++ b/tensorflow/python/autograph/converters/slices.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import templates
class SliceTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py
index d74b2e025e..e190a7cfe8 100644
--- a/tensorflow/contrib/autograph/converters/slices_test.py
+++ b/tensorflow/python/autograph/converters/slices_test.py
@@ -18,12 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import converter_testing
-from tensorflow.contrib.autograph.lang import directives
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
diff --git a/tensorflow/contrib/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 1873045a92..85fecf084d 100644
--- a/tensorflow/contrib/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -25,9 +25,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
],
)
@@ -65,10 +65,10 @@ py_library(
visibility = ["//tensorflow:__subpackages__"],
deps = [
":core",
- "//tensorflow/contrib/autograph/operators",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
index 878bb7e12f..4fa8489af5 100644
--- a/tensorflow/contrib/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import utils
+from tensorflow.python.autograph import utils
PYTHON_LITERALS = {
@@ -36,7 +36,7 @@ DEFAULT_UNCOMPILED_MODULES = set((
# have well-known names. Not referring to the module directly to avoid
# circular imports.
(
- utils.__name__[:-len('.contrib.autograph.utils')],),
+ utils.__name__[:-len('.python.autograph.utils')],),
))
NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 83a80c1f52..80928ae7f4 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -63,23 +63,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from enum import Enum
-
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import naming
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import naming
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
# TODO(mdan): These contexts can be refactored into first class objects.
# For example, we could define Program and Entity abstractions that hold on
@@ -129,9 +127,8 @@ class ProgramContext(object):
self.autograph_module = autograph_module
self.uncompiled_modules = uncompiled_modules
- # Required to output dependencies in discovery order, which should match
- # the reverse dependency order.
- self.dependency_cache = collections.OrderedDict()
+ self.conversion_order = []
+ self.dependency_cache = {}
self.additional_imports = set()
self.name_map = {}
@@ -177,6 +174,7 @@ class ProgramContext(object):
self.name_map[o] = name
def add_to_cache(self, original_entity, converted_ast):
+ self.conversion_order.append(original_entity)
self.dependency_cache[original_entity] = converted_ast
diff --git a/tensorflow/contrib/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 5ee2c3fffd..7ce1b7c4c5 100644
--- a/tensorflow/contrib/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -24,15 +24,15 @@ import sys
import six
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import pretty_printer
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
@@ -93,11 +93,21 @@ class TestCase(test.TestCase):
self.dynamic_calls.append(args)
return 7
+ class ConversionOptions(object):
+ """Mock version of api.ConversionOptions."""
+
+ def __init__(self, recursive):
+ self.recursive = recursive
+
+ @classmethod
+ def new(cls, recursive):
+ cls(recursive)
+
try:
result, source = compiler.ast_to_object(node, include_source_map=True)
result.tf = self.make_fake_mod('fake_tf', *symbols)
- fake_ag = self.make_fake_mod('fake_ag', converted_call)
+ fake_ag = self.make_fake_mod('fake_ag', converted_call, ConversionOptions)
fake_ag.__dict__.update(operators.__dict__)
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
diff --git a/tensorflow/contrib/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
index 5a57d57e7d..23f8c5b52b 100644
--- a/tensorflow/contrib/autograph/core/errors.py
+++ b/tensorflow/python/autograph/core/errors.py
@@ -31,7 +31,7 @@ import logging
import sys
import traceback
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.framework import errors_impl
# TODO(mdan): Add a superclass common to all errors.
@@ -208,7 +208,6 @@ def rewrite_tf_runtime_error(error, source_map):
"""
try:
cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback)
- # cleaned_traceback = error.op.traceback
cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
op_name = error.op.name
diff --git a/tensorflow/contrib/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
index 404c1f5456..aa6c293268 100644
--- a/tensorflow/contrib/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors as tf_errors
from tensorflow.python.ops import array_ops
@@ -54,7 +54,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
for frame in cm.exception.custom_traceback:
@@ -69,7 +69,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
all_function_names = set()
@@ -86,7 +86,7 @@ class RuntimeErrorsTest(test.TestCase):
ops = zero_div_caller()
with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(ops)
def test_improved_errors_validation(self):
diff --git a/tensorflow/contrib/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py
index b1d3f76be7..aecc9e33ca 100644
--- a/tensorflow/contrib/autograph/core/naming.py
+++ b/tensorflow/python/autograph/core/naming.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import qual_names
class Namer(object):
diff --git a/tensorflow/contrib/autograph/core/naming_test.py b/tensorflow/python/autograph/core/naming_test.py
index d2bebd0478..2db98836d1 100644
--- a/tensorflow/contrib/autograph/core/naming_test.py
+++ b/tensorflow/python/autograph/core/naming_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.core import naming
+from tensorflow.python.autograph.core import naming
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
index c2427f5f4f..c2427f5f4f 100644
--- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
+++ b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD
index a5438592c3..bef62a6403 100644
--- a/tensorflow/contrib/autograph/impl/BUILD
+++ b/tensorflow/python/autograph/impl/BUILD
@@ -23,14 +23,14 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/converters",
- "//tensorflow/contrib/autograph/core",
- "//tensorflow/contrib/autograph/operators",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/pyct/static_analysis",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/converters",
+ "//tensorflow/python/autograph/core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
"@six_archive//:six",
],
@@ -43,8 +43,8 @@ py_test(
tags = ["no_windows"],
deps = [
":impl",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/utils",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 8b38d5d080..1dc97d2331 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -18,21 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from functools import wraps
+import collections
+import functools
from enum import Enum
-# pylint:disable=g-bad-import-order
-import six
-# pylint:enable=g-bad-import-order
-
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import conversion
-from tensorflow.contrib.autograph.operators import py_builtins
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import conversion
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -42,6 +39,41 @@ from tensorflow.python.util import tf_inspect
# (currently we require (module + class name, type))
+class ConversionOptions(
+ collections.namedtuple('ConversionOptions',
+ ('recursive', 'verbose', 'strip_decorators',
+ 'force_conversion', 'arg_types'))):
+ """Container for conversion flags.
+
+ Attributes:
+ recursive: bool, whether to recursively convert any user functions or
+ classes that the converted function may use.
+ verbose: bool, whether to log the compiled code.
+ strip_decorators: Tuple[Callable], contains decorators that should be in
+ excluded from the compiled output. By default, when converting a
+ function before the decorators are applied, the compiled output will
+ include those decorators.
+ force_conversion: bool, whether to force convertinng the target entity.
+ When force_conversion is turned off, the converter may decide to
+ return the function as-is.
+ arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+ function arguments.
+ """
+
+ @classmethod
+ def new(cls,
+ recursive=False,
+ verbose=False,
+ strip_decorators=None,
+ force_conversion=False,
+ arg_types=None):
+ return cls(recursive=recursive,
+ verbose=verbose,
+ strip_decorators=strip_decorators or (),
+ force_conversion=force_conversion,
+ arg_types=arg_types or {})
+
+
# TODO(mdan): This should behave like to_graph (e.g. convert statically).
def convert(recursive=False, verbose=False):
"""Decorator that compiles a function to use TensorFlow ops.
@@ -63,9 +95,15 @@ def convert(recursive=False, verbose=False):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, True, {}, *args, **kwargs)
+ return converted_call(
+ f,
+ ConversionOptions.new(
+ recursive=recursive,
+ verbose=verbose,
+ force_conversion=True,
+ ), *args, **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -111,11 +149,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
def decorator(f):
"""Decorator implementation."""
- @wraps(f)
+ @functools.wraps(f)
def graph_wrapper(*args, **kwargs):
return f(*args, **kwargs)
- @wraps(f)
+ @functools.wraps(f)
def py_func_wrapper(*args, **kwargs):
if kwargs:
raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
@@ -139,12 +177,11 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
# TODO(mdan): Move to a private, undocumented module.
-def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
- **kwargs):
+def converted_call(f, options, *args, **kwargs):
"""Compiles a function call inline. For internal use only."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
- if not force_conversion and conversion.is_whitelisted_for_graph(f):
+ if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
@@ -187,8 +224,8 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
continue
arg_class = arg.__class__
# If arg_value_hints specifies any name, use that instead.
- if name not in arg_types:
- arg_types[name] = (arg_class.__name__, arg_class)
+ if name not in options.arg_types:
+ options.arg_types[name] = (arg_class.__name__, arg_class)
# When called from within a decorator, this is the only indication that
# the function is a method - it appears that the decorator is applied
@@ -203,23 +240,25 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
converted_f = to_graph(
target_entity,
- recursive=recursive,
- verbose=verbose,
+ recursive=options.recursive,
+ verbose=options.verbose,
arg_values=arg_values,
- arg_types=arg_types,
- partial_types=partial_types)
+ arg_types=options.arg_types,
+ partial_types=partial_types,
+ strip_decorators=options.strip_decorators)
return converted_f(*effective_args, **kwargs)
# TODO(mdan): Rename: to_ops?
-# TODO(mdan): Looki into overloading as function and decorator, like tfe.defun.
+# TODO(mdan): Look into overloading as function and decorator, like tfe.defun?
# TODO(mdan): Remove partial_types.
def to_graph(e,
recursive=True,
verbose=False,
arg_values=None,
arg_types=None,
- partial_types=None):
+ partial_types=None,
+ strip_decorators=None):
"""Converts a Python entity into equivalent code that uses TensorFlow ops.
Supported Python entities include:
@@ -238,6 +277,8 @@ def to_graph(e,
arg_types: Optional[Dict[Text, Type]], type hints for symbols including
function arguments.
partial_types: Set[Type], reserved for internal use.
+ strip_decorators: Tuple[Callable], same as
+ ConversionOptions.strip_decorators.
Returns:
Union[Callable, Type], the converted entity, which is the same kind as e
@@ -247,9 +288,13 @@ def to_graph(e,
Raises:
ValueError: If the entity could not be converted.
"""
+ if strip_decorators is None:
+ strip_decorators = ()
+ strip_decorators += (convert, do_not_convert, converted_call)
+
program_ctx = converter.ProgramContext(
recursive=recursive,
- autograph_decorators=(convert, do_not_convert, converted_call),
+ autograph_decorators=strip_decorators,
partial_types=partial_types,
autograph_module=tf_inspect.getmodule(to_graph),
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
@@ -257,8 +302,9 @@ def to_graph(e,
arg_types)
nodes = []
- for dep in reversed(program_ctx.dependency_cache.values()):
- nodes.extend(dep)
+ for dep in reversed(program_ctx.conversion_order):
+ nodes.extend(program_ctx.dependency_cache[dep])
+
compiled_module, compiled_src = compiler.ast_to_object(
nodes,
source_prefix=program_ctx.required_imports,
@@ -326,7 +372,7 @@ def to_code(e,
conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
code = '\n'.join(
- compiler.ast_to_source(dep, indentation)
- for dep in reversed(tuple(six.itervalues(program_ctx.dependency_cache))))
+ compiler.ast_to_source(program_ctx.dependency_cache[dep], indentation)
+ for dep in reversed(program_ctx.conversion_order))
return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index a4c6fed265..8ce5022c0a 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -32,7 +32,6 @@ from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
-
class ApiTest(test.TestCase):
def setUp(self):
@@ -56,7 +55,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -76,7 +75,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -97,7 +96,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -123,7 +122,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -146,7 +145,7 @@ class ApiTest(test.TestCase):
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
@@ -180,19 +179,20 @@ class ApiTest(test.TestCase):
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
- x //= api.converted_call(self.called_member, False, False, False, {},
- self, a)
+ x //= api.converted_call(
+ self.called_member,
+ api.ConversionOptions.new(), self, a)
return x
tc = TestClass()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = tc.test_method(
constant_op.constant([2, 4]), constant_op.constant(1),
constant_op.constant(-2))
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
- x = api.converted_call(range, False, False, False, {}, 3)
+ x = api.converted_call(range, api.ConversionOptions.new(), 3)
self.assertEqual((0, 1, 2), tuple(x))
def test_converted_call_function(self):
@@ -202,8 +202,8 @@ class ApiTest(test.TestCase):
return -x
return x
- with self.test_session() as sess:
- x = api.converted_call(test_fn, False, False, False, {},
+ with self.cached_session() as sess:
+ x = api.converted_call(test_fn, api.ConversionOptions.new(),
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
@@ -219,9 +219,9 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+ x = api.converted_call(tc.test_method, api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@@ -236,9 +236,11 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+ x = api.converted_call(
+ TestClass.test_method,
+ api.ConversionOptions.new(), tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@@ -253,9 +255,9 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc, False, False, False, {})
+ x = api.converted_call(tc, api.ConversionOptions.new())
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@@ -270,8 +272,8 @@ class ApiTest(test.TestCase):
return -self.x
return self.x
- with self.test_session() as sess:
- tc = api.converted_call(TestClass, False, False, False, {},
+ with self.cached_session() as sess:
+ tc = api.converted_call(TestClass, api.ConversionOptions.new(),
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
@@ -282,13 +284,13 @@ class ApiTest(test.TestCase):
def f(x):
return x == 0
- with self.test_session() as sess:
- x = api.converted_call(f, False, False, False, {},
+ with self.cached_session() as sess:
+ x = api.converted_call(f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
- x = api.converted_call(converted_f, False, False, False, {},
+ x = api.converted_call(converted_f, api.ConversionOptions.new(),
constant_op.constant(0))
self.assertTrue(sess.run(x))
@@ -301,7 +303,7 @@ class ApiTest(test.TestCase):
compiled_fn = api.to_graph(test_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = compiled_fn(constant_op.constant([4, 8]), 4)
self.assertListEqual([1, 2], sess.run(x).tolist())
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index fc8a976d3f..a0d13c82a8 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -22,34 +22,34 @@ import imp
import gast
-from tensorflow.contrib.autograph import operators
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.converters import asserts
-from tensorflow.contrib.autograph.converters import break_statements
-from tensorflow.contrib.autograph.converters import builtin_functions
-from tensorflow.contrib.autograph.converters import call_trees
-from tensorflow.contrib.autograph.converters import conditional_expressions
-from tensorflow.contrib.autograph.converters import continue_statements
-from tensorflow.contrib.autograph.converters import control_flow
-from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import directives
-from tensorflow.contrib.autograph.converters import error_handlers
-from tensorflow.contrib.autograph.converters import lists
-from tensorflow.contrib.autograph.converters import logical_expressions
-from tensorflow.contrib.autograph.converters import name_scopes
-from tensorflow.contrib.autograph.converters import return_statements
-from tensorflow.contrib.autograph.converters import side_effect_guards
-from tensorflow.contrib.autograph.converters import slices
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.core import errors
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.converters import directives
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -255,6 +255,7 @@ def _add_self_references(namespace, autograph_module):
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.converted_call = autograph_module.converted_call
+ ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
index 86432573a7..07d0f75129 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph import utils
-from tensorflow.contrib.autograph.core import config
-from tensorflow.contrib.autograph.core import converter
-from tensorflow.contrib.autograph.impl import api
-from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.impl import conversion
from tensorflow.python.framework import constant_op
from tensorflow.python.keras.engine import training
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD
index 77a2184e22..462349cc10 100644
--- a/tensorflow/contrib/autograph/lang/BUILD
+++ b/tensorflow/python/autograph/lang/BUILD
@@ -25,7 +25,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/operators",
+ "//tensorflow/python/autograph/operators",
],
)
diff --git a/tensorflow/contrib/autograph/lang/directives.py b/tensorflow/python/autograph/lang/directives.py
index aabe5d9939..aabe5d9939 100644
--- a/tensorflow/contrib/autograph/lang/directives.py
+++ b/tensorflow/python/autograph/lang/directives.py
diff --git a/tensorflow/contrib/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
index 6149cbbd6c..e4838d1b6d 100644
--- a/tensorflow/contrib/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -23,7 +23,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
def tensor_list(elements,
diff --git a/tensorflow/contrib/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index db492cc5c6..545dd11729 100644
--- a/tensorflow/contrib/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.lang import special_functions
+from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
@@ -33,7 +33,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@@ -41,7 +41,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index 29759bad79..a116611b64 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -28,7 +28,6 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
@@ -38,6 +37,7 @@ py_library(
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variables",
+ "//tensorflow/python/autograph/utils",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -66,6 +66,7 @@ py_test(
name = "py_builtins_test",
srcs = ["py_builtins_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_windows"],
deps = [
":operators",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index c4fbc260a2..0d3b44b6c4 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -37,19 +37,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators.control_flow import for_stmt
-from tensorflow.contrib.autograph.operators.control_flow import while_stmt
-from tensorflow.contrib.autograph.operators.data_structures import list_append
-from tensorflow.contrib.autograph.operators.data_structures import list_pop
-from tensorflow.contrib.autograph.operators.data_structures import list_stack
-from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
-from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
-from tensorflow.contrib.autograph.operators.data_structures import new_list
-from tensorflow.contrib.autograph.operators.py_builtins import float_
-from tensorflow.contrib.autograph.operators.py_builtins import int_
-from tensorflow.contrib.autograph.operators.py_builtins import len_
-from tensorflow.contrib.autograph.operators.py_builtins import print_
-from tensorflow.contrib.autograph.operators.py_builtins import range_
-from tensorflow.contrib.autograph.operators.slices import get_item
-from tensorflow.contrib.autograph.operators.slices import GetItemOpts
-from tensorflow.contrib.autograph.operators.slices import set_item
+from tensorflow.python.autograph.operators.control_flow import for_stmt
+from tensorflow.python.autograph.operators.control_flow import while_stmt
+from tensorflow.python.autograph.operators.data_structures import list_append
+from tensorflow.python.autograph.operators.data_structures import list_pop
+from tensorflow.python.autograph.operators.data_structures import list_stack
+from tensorflow.python.autograph.operators.data_structures import ListPopOpts
+from tensorflow.python.autograph.operators.data_structures import ListStackOpts
+from tensorflow.python.autograph.operators.data_structures import new_list
+from tensorflow.python.autograph.operators.py_builtins import float_
+from tensorflow.python.autograph.operators.py_builtins import int_
+from tensorflow.python.autograph.operators.py_builtins import len_
+from tensorflow.python.autograph.operators.py_builtins import print_
+from tensorflow.python.autograph.operators.py_builtins import range_
+from tensorflow.python.autograph.operators.slices import get_item
+from tensorflow.python.autograph.operators.slices import GetItemOpts
+from tensorflow.python.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 9a66a6bb60..6eedd695a7 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
index 677b7f8f62..bb214b6f16 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import control_flow
+from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
index cc0a3c3544..cc0a3c3544 100644
--- a/tensorflow/contrib/autograph/operators/data_structures.py
+++ b/tensorflow/python/autograph/operators/data_structures.py
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index 4b1e835d44..8532dbe466 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import data_structures
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/autograph/operators/dispatch_context.py b/tensorflow/python/autograph/operators/dispatch_context.py
index 097002465b..097002465b 100644
--- a/tensorflow/contrib/autograph/operators/dispatch_context.py
+++ b/tensorflow/python/autograph/operators/dispatch_context.py
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index c5730934e7..91a2a22cc2 100644
--- a/tensorflow/contrib/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -23,8 +23,8 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -193,11 +193,18 @@ def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
def _tf_range(start_or_stop, stop, step):
+ # Note: for static inputs (e.g. constants), tf.range errors out at graph
+ # construction time, instead of returning an empty tensor. Preventing the
+ # graph construction error aligns the semantics with Python.
+
# TODO(mdan): We should optimize this when a full tensor is not required.
if step is not UNDEFINED:
+ # TODO(mdan): Add argument coercion similar to other cases.
return math_ops.range(start_or_stop, stop, step)
if stop is not UNDEFINED:
+ stop = math_ops.maximum(start_or_stop, stop)
return math_ops.range(start_or_stop, stop)
+ start_or_stop = math_ops.maximum(start_or_stop, 0)
return math_ops.range(start_or_stop)
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index 4073c51785..c94a918d5a 100644
--- a/tensorflow/contrib/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -22,8 +22,8 @@ import sys
import six
-from tensorflow.contrib.autograph.operators import data_structures
-from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -36,7 +36,7 @@ class PyBuiltinsTest(test.TestCase):
def test_abs(self):
self.assertEqual(py_builtins.abs_(-1), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.abs_(constant_op.constant(-1))
self.assertEqual(sess.run(t), 1)
t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
@@ -45,7 +45,7 @@ class PyBuiltinsTest(test.TestCase):
def test_float(self):
self.assertEqual(py_builtins.float_(10), 10.0)
self.assertEqual(py_builtins.float_('10.0'), 10.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
self.assertEqual(sess.run(t), 1.0)
st = py_builtins.float_(constant_op.constant('1.0'))
@@ -54,7 +54,7 @@ class PyBuiltinsTest(test.TestCase):
def test_int(self):
self.assertEqual(py_builtins.int_(10.0), 10)
self.assertEqual(py_builtins.int_('11', 2), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
self.assertEqual(sess.run(t), 1)
st = py_builtins.int_(constant_op.constant('1'))
@@ -69,7 +69,7 @@ class PyBuiltinsTest(test.TestCase):
def test_len(self):
self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
self.assertEqual(t, 3)
ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
@@ -82,7 +82,7 @@ class PyBuiltinsTest(test.TestCase):
py_builtins.len_(constant_op.constant(1))
def test_len_dynamic_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
t = py_builtins.len_(p)
self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
@@ -95,7 +95,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
finally:
@@ -105,7 +105,7 @@ class PyBuiltinsTest(test.TestCase):
try:
out_capturer = six.StringIO()
sys.stdout = out_capturer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
py_builtins.print_(constant_op.constant('test message'), [1, 2]))
self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
@@ -118,7 +118,7 @@ class PyBuiltinsTest(test.TestCase):
self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
def test_range_tensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
r = py_builtins.range_(constant_op.constant(3))
self.assertAllEqual(sess.run(r), [0, 1, 2])
r = py_builtins.range_(1, constant_op.constant(3))
@@ -126,6 +126,13 @@ class PyBuiltinsTest(test.TestCase):
r = py_builtins.range_(2, 0, constant_op.constant(-1))
self.assertAllEqual(sess.run(r), [2, 1])
+ def test_range_tensor_empty_range(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(-3))
+ self.assertAllEqual(sess.run(r), [])
+ r = py_builtins.range_(5, constant_op.constant(2))
+ self.assertAllEqual(sess.run(r), [])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/python/autograph/operators/slices.py
index 2b7f5ad922..2b7f5ad922 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/python/autograph/operators/slices.py
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
index 5255b7e2b6..9e4865b3c6 100644
--- a/tensorflow/contrib/autograph/operators/slices_test.py
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.autograph.operators import slices
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test
@@ -51,14 +51,14 @@ class SlicesTest(test.TestCase):
t = slices.get_item(initial_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'b')
initial_list_str = constant_op.constant(['abcd', 'bcde'])
t = slices.get_item(initial_list_str, 1,
slices.GetItemOpts(element_dtype=initial_str.dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(t), b'bcde')
diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
index ddadc6b96e..ddadc6b96e 100644
--- a/tensorflow/contrib/autograph/pyct/BUILD
+++ b/tensorflow/python/autograph/pyct/BUILD
diff --git a/tensorflow/contrib/autograph/pyct/__init__.py b/tensorflow/python/autograph/pyct/__init__.py
index d787e56bbe..d787e56bbe 100644
--- a/tensorflow/contrib/autograph/pyct/__init__.py
+++ b/tensorflow/python/autograph/pyct/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index 1a52110ef3..1a52110ef3 100644
--- a/tensorflow/contrib/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/python/autograph/pyct/anno_test.py
index 5ef4da61a3..1f873871c6 100644
--- a/tensorflow/contrib/autograph/pyct/anno_test.py
+++ b/tensorflow/python/autograph/pyct/anno_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import ast
-from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
index d7453b0781..7df3b8858c 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util.py
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -22,8 +22,8 @@ import ast
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
class CleanCopier(object):
diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py
index 2293c89720..b1577c466e 100644
--- a/tensorflow/contrib/autograph/pyct/ast_util_test.py
+++ b/tensorflow/python/autograph/pyct/ast_util_test.py
@@ -22,11 +22,11 @@ import ast
import collections
import textwrap
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index ba51dcf285..fca0eb62e4 100644
--- a/tensorflow/contrib/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -27,13 +27,14 @@ from __future__ import division
from __future__ import print_function
import collections
+import weakref
from enum import Enum
# pylint:disable=g-bad-import-order
import gast
# pylint:enable=g-bad-import-order
-from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import compiler
class Node(object):
@@ -61,7 +62,10 @@ class Node(object):
def freeze(self):
self.next = frozenset(self.next)
- self.prev = frozenset(self.prev)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ self.prev = weakref.WeakSet(self.prev)
def __repr__(self):
if isinstance(self.ast_node, gast.FunctionDef):
@@ -256,7 +260,7 @@ class GraphBuilder(object):
"""Resets the state of this factory."""
self.head = None
self.errors = set()
- self.node_index = collections.OrderedDict()
+ self.node_index = {}
# TODO(mdan): Too many primitives. Use classes.
self.leaves = set()
@@ -309,7 +313,10 @@ class GraphBuilder(object):
"""Grows the graph by adding a CFG node following the current leaves."""
if ast_node is self.node_index:
raise ValueError('%s added twice' % ast_node)
- node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node)
self.node_index[ast_node] = node
self.owners[node] = frozenset(self.active_stmts)
diff --git a/tensorflow/contrib/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
index 9d0a85d615..bd82e70f7d 100644
--- a/tensorflow/contrib/autograph/pyct/cfg_test.py
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD
index fe630ef852..5e2f8f3ac0 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD
@@ -26,7 +26,7 @@ py_library(
"@six_archive//:six",
# TODO(aqj) Revisit this dependency direction when pyct is more
# modularized
- "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py
index d77c15915b..192621b1cd 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py
@@ -29,8 +29,8 @@ from __future__ import print_function
import gast
import six
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
class DummyGensym(object):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
index 1ffd4bbe55..ccc7e4ca8f 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -20,10 +20,10 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.common_transformers import anf
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.common_transformers import anf
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
index f9cee10962..21281aeb56 100644
--- a/tensorflow/contrib/autograph/pyct/compiler.py
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -30,7 +30,7 @@ import tempfile
import astor
import gast
-from tensorflow.contrib.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import origin_info
def ast_to_source(node, indentation=' '):
@@ -57,8 +57,15 @@ def ast_to_source(node, indentation=' '):
# In some versions of Python, literals may appear as actual values. This
# ensures everything is string.
- code = map(str, generator.result)
- code = astor.source_repr.pretty_source(code).lstrip()
+ code = ''.join(map(str, generator.result))
+
+ # Strip leading blank lines.
+ code_lines = code.split('\n')
+ trimmed_code_lines = []
+ for l in code_lines:
+ if l.rstrip() or trimmed_code_lines:
+ trimmed_code_lines.append(l)
+ code = '\n'.join(trimmed_code_lines)
return code
@@ -108,7 +115,7 @@ def ast_to_object(nodes,
indices = (-1,)
if include_source_map:
- source_map = origin_info.source_map(nodes, source, f.name, indices)
+ source_map = origin_info.create_source_map(nodes, source, f.name, indices)
# TODO(mdan): Try flush() and delete=False instead.
if delete_on_exit:
diff --git a/tensorflow/contrib/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py
index cf783da6a3..6fa289d3cc 100644
--- a/tensorflow/contrib/autograph/pyct/compiler_test.py
+++ b/tensorflow/python/autograph/pyct/compiler_test.py
@@ -22,8 +22,8 @@ import textwrap
import gast
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index eef74599a7..eef74599a7 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index 1a212f676a..f3eb027822 100644
--- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -22,7 +22,7 @@ from functools import wraps
import six
-from tensorflow.contrib.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index b60651a30e..102bd42c91 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -23,9 +23,9 @@ import tokenize
import gast
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.util import tf_inspect
@@ -75,7 +75,7 @@ class OriginInfo(
# TODO(mdan): This source map should be a class - easier to refer to.
-def source_map(nodes, code, filename, indices_in_code):
+def create_source_map(nodes, code, filename, indices_in_code):
"""Creates a source map between an annotated AST and the code it compiles to.
Args:
diff --git a/tensorflow/contrib/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
index eeaa13007e..3b1d5f2040 100644
--- a/tensorflow/contrib/autograph/pyct/origin_info_test.py
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -18,58 +18,50 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import origin_info
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
class OriginInfoTest(test.TestCase):
- def test_source_map(self):
+ def test_create_source_map(self):
def test_fn(x):
- if x > 0:
- x += 1
- return x
-
- node, source = parser.parse_entity(test_fn)
+ return x + 1
+
+ node, _ = parser.parse_entity(test_fn)
+ fake_origin = origin_info.OriginInfo(
+ loc=origin_info.Location('fake_filename', 3, 7),
+ function_name='fake_function_name',
+ source_code_line='fake source line',
+ comment=None)
fn_node = node.body[0]
- origin_info.resolve(fn_node, source)
-
- # Insert a traced line.
- new_node = parser.parse_str('x = abs(x)').body[0]
- anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
- fn_node.body.insert(0, new_node)
+ anno.setanno(fn_node.body[0], anno.Basic.ORIGIN, fake_origin)
+ converted_code = compiler.ast_to_source(fn_node)
- # Insert an untraced line.
- fn_node.body.insert(0, parser.parse_str('x = 0').body[0])
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- modified_source = compiler.ast_to_source(fn_node)
+ loc = origin_info.LineLocation('test_filename', 2)
+ self.assertIn(loc, source_map)
+ self.assertIs(source_map[loc], fake_origin)
- source_map = origin_info.source_map(fn_node, modified_source,
- 'test_filename', [0])
+ def test_source_map_no_origin(self):
- loc = origin_info.LineLocation('test_filename', 1)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, 'def test_fn(x):')
- self.assertEqual(origin.loc.lineno, 1)
+ def test_fn(x):
+ return x + 1
- # The untraced line, inserted second.
- loc = origin_info.LineLocation('test_filename', 2)
- self.assertFalse(loc in source_map)
+ node, _ = parser.parse_entity(test_fn)
+ fn_node = node.body[0]
+ converted_code = compiler.ast_to_source(fn_node)
- # The traced line, inserted first.
- loc = origin_info.LineLocation('test_filename', 3)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ source_map = origin_info.create_source_map(
+ fn_node, converted_code, 'test_filename', [0])
- loc = origin_info.LineLocation('test_filename', 4)
- origin = source_map[loc]
- self.assertEqual(origin.source_code_line, ' if x > 0:')
- self.assertEqual(origin.loc.lineno, 2)
+ self.assertEqual(len(source_map), 0)
def test_resolve(self):
@@ -79,6 +71,7 @@ class OriginInfoTest(test.TestCase):
node, source = parser.parse_entity(test_fn)
fn_node = node.body[0]
+
origin_info.resolve(fn_node, source)
origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
diff --git a/tensorflow/contrib/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index 112ed46a1e..63686350d5 100644
--- a/tensorflow/contrib/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -31,8 +31,21 @@ from tensorflow.python.util import tf_inspect
def parse_entity(entity):
"""Returns the AST of given entity."""
source = tf_inspect.getsource(entity)
+ # Comments and multiline strings can appear at arbitrary indentation levels,
+ # causing textwrap.dedent to not correctly dedent source code.
+ # TODO(b/115884650): Automatic handling of comments/multiline strings.
source = textwrap.dedent(source)
- return parse_str(source), source
+ try:
+ return parse_str(source), source
+ except IndentationError:
+ # Because we are parsing the source code of entities that have already
+ # successfully parsed once, any IndentationErrors are guaranteed to be
+ # caused by insufficient dedenting.
+ raise ValueError(
+ 'Failed to dedent prior to parsing source code. If you have comments '
+ 'or multiline strings in your code, try indenting them. '
+ 'Multiline strings can be rewritten using textwrap.dedent.\n'
+ 'Offending source code: \n %s' % source)
def parse_str(src):
diff --git a/tensorflow/contrib/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
index 007a4c6fb0..d3a7b7a014 100644
--- a/tensorflow/contrib/autograph/pyct/parser_test.py
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
@@ -42,6 +42,22 @@ class ParserTest(test.TestCase):
"""))
self.assertEqual('f', mod.body[0].name)
+ def test_parse_comments(self):
+ def f():
+# unindented comment
+ pass
+ with self.assertRaises(ValueError):
+ parser.parse_entity(f)
+
+ def test_parse_multiline_strings(self):
+ def f():
+ print("""
+some
+multiline
+string""")
+ with self.assertRaises(ValueError):
+ parser.parse_entity(f)
+
def test_parse_expression(self):
node = parser.parse_expression('a.b')
self.assertEqual('a', node.value.id)
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
index bacc1e4a77..bacc1e4a77 100644
--- a/tensorflow/contrib/autograph/pyct/pretty_printer.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer.py
diff --git a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py b/tensorflow/python/autograph/pyct/pretty_printer_test.py
index 0cb48f3576..1c76744547 100644
--- a/tensorflow/contrib/autograph/pyct/pretty_printer_test.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import ast
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py
index fb81404edc..334cbd7d38 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/python/autograph/pyct/qual_names.py
@@ -29,8 +29,8 @@ import collections
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
class Symbol(collections.namedtuple('Symbol', ['name'])):
diff --git a/tensorflow/contrib/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py
index c793c2bb39..2da4dfd787 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names_test.py
+++ b/tensorflow/python/autograph/pyct/qual_names_test.py
@@ -20,11 +20,11 @@ from __future__ import print_function
import textwrap
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.qual_names import resolve
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.qual_names import resolve
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD
index 92eacba3fd..4a4ccdcbd1 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD
@@ -27,9 +27,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
],
)
@@ -41,8 +41,8 @@ py_test(
tags = ["no_windows"],
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
@@ -54,8 +54,8 @@ py_test(
tags = ["no_windows"],
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -65,8 +65,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -76,8 +76,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
],
)
@@ -87,8 +87,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":static_analysis",
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
index 9a82de735d..9a82de735d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/__init__.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index a0182da9d1..086eda7574 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -22,13 +22,14 @@ from __future__ import division
from __future__ import print_function
import copy
+import weakref
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
# TODO(alexbw): Ignore named literals (e.g. None)
@@ -126,7 +127,10 @@ class Scope(object):
self.parent.mark_read(name)
def mark_param(self, name, owner):
- self.params[name] = owner
+ # Assumption: all AST nodes have the same life span. This lets us use
+ # a weak reference to mark the connection between a symbol node and the
+ # function node whose argument that symbol is.
+ self.params[name] = weakref.ref(owner)
def mark_creation(self, name, writes_create_symbol=False):
"""Mark a qualified name as created."""
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index e940516190..d4a6ce8ac3 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -20,13 +20,13 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.qual_names import QN
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/python/autograph/pyct/static_analysis/annos.py
index 5eefecf278..5eefecf278 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/annos.py
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index e7baa244b2..36b9e7074d 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -25,14 +25,15 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
# TODO(aqj): Do we need this? Do other builtins fail in similar ways
# See b/114389775 for a related bug in pyct
# These symbols are legal in Python, but don't appear in the namespace.
-_special_symbols = {'range': range}
+_SPECIAL_SYMBOLS = {'range': range, 'print': print}
class LiveValueResolver(transformer.Base):
@@ -71,8 +72,10 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
- elif node.id in _special_symbols:
- anno.setanno(node, 'live_val', _special_symbols[node.id])
+ elif node.id in _SPECIAL_SYMBOLS:
+ # Note: if the user redefined any of these symbols, then they would
+ # be visible in the namespace and we would never reach this branch.
+ anno.setanno(node, 'live_val', _SPECIAL_SYMBOLS[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
@@ -86,7 +89,8 @@ class LiveValueResolver(transformer.Base):
if has_single_def:
def_, = defs
- if def_.param_of is self.enclosing_entities[0]:
+ # Note: param_of is a weakref.
+ if def_.param_of and def_.param_of() is self.enclosing_entities[0]:
if node.id in self.entity_info.arg_values:
obj = self.entity_info.arg_values[node.id]
anno.setanno(node, 'live_val', obj)
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
index fe3051179c..882c380b78 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
@@ -20,15 +20,15 @@ from __future__ import print_function
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
index bf29d868a2..41c903beb9 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -26,10 +26,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
class Analyzer(cfg.GraphVisitor):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index d53adb28af..0d5f369e92 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
index 7f2b379d3d..9aaf318a9f 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
@@ -30,10 +30,10 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import annos
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
class Definition(object):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index 243fe804b2..373a2cb38f 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -18,13 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
index 835d5199fa..edb2ef0e27 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
@@ -43,9 +43,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
index 404311ba24..34ba3d2f13 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import cfg
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis import activity
-from tensorflow.contrib.autograph.pyct.static_analysis import live_values
-from tensorflow.contrib.autograph.pyct.static_analysis import reaching_definitions
-from tensorflow.contrib.autograph.pyct.static_analysis import type_info
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
from tensorflow.python.client import session
from tensorflow.python.platform import test
from tensorflow.python.training import training
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index d81c50f524..1bf0515745 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -26,10 +26,10 @@ import textwrap
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import ast_util
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
class ReplaceTransformer(gast.NodeTransformer):
@@ -109,6 +109,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if not node.ctx:
raise ValueError('node %s is missing ctx value' % node)
+ # TODO(mdan): Rewrite _check and _set using a separate transformer.
def _check_inner_children_have_context(self, node):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
@@ -131,6 +132,11 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(node.upper)
if node.step:
self._check_inner_children_have_context(node.step)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, gast.Name):
self._check_has_context(node)
elif isinstance(node, (gast.Str, gast.Num)):
@@ -166,6 +172,11 @@ class ReplaceTransformer(gast.NodeTransformer):
elif isinstance(node, gast.Subscript):
self._set_inner_child_context(node.value, ctx)
self._check_inner_children_have_context(node.slice)
+ elif isinstance(node, gast.BinOp):
+ self._check_inner_children_have_context(node.left)
+ self._check_inner_children_have_context(node.right)
+ elif isinstance(node, gast.UnaryOp):
+ self._check_inner_children_have_context(node.operand)
elif isinstance(node, (gast.Str, gast.Num)):
pass
else:
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 074105ea50..078d9a149b 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -22,9 +22,9 @@ import imp
import gast
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
from tensorflow.python.platform import test
@@ -132,6 +132,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+ def test_replace_expression_context(self):
+ template = """
+ def test_fn(foo):
+ foo
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('a + 2 * b / -c'))[0]
+ self.assertIsInstance(node.body[0].ctx, gast.Load)
+ self.assertIsInstance(node.body[0].left.ctx, gast.Load)
+ self.assertIsInstance(node.body[0].right.left.right.ctx, gast.Load)
+
def test_replace_complex_context(self):
template = """
def test_fn(foo):
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD
index 29a92444bb..c244cbd747 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/python/autograph/pyct/testing/BUILD
@@ -22,8 +22,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
- "//tensorflow/contrib/autograph/utils",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
"@gast_archive//:gast",
],
)
@@ -41,8 +41,8 @@ py_test(
],
deps = [
":testing",
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/python/autograph/pyct/testing/codegen.py
index 279e7c09dc..78b24390c3 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen.py
@@ -24,7 +24,7 @@ import string
import gast
import numpy as np
-from tensorflow.contrib.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import templates
class NodeSampler(object):
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/python/autograph/pyct/testing/codegen_test.py
index 255c3b2a2e..71665be039 100644
--- a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py
+++ b/tensorflow/python/autograph/pyct/testing/codegen_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct.testing import codegen
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct.testing import codegen
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
index 969ca12244..520f5038da 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -23,9 +23,9 @@ import sys
import gast
import six
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import compiler
-from tensorflow.contrib.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import pretty_printer
class AutographParseError(SyntaxError):
diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py
index a37e922a1d..23bf9a8e16 100644
--- a/tensorflow/contrib/autograph/pyct/transformer_test.py
+++ b/tensorflow/python/autograph/pyct/transformer_test.py
@@ -20,9 +20,9 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
-from tensorflow.contrib.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD
index 4504a5c7a3..22451d4f3f 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/python/autograph/utils/BUILD
@@ -32,10 +32,10 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
"//tensorflow/python:dtypes",
"//tensorflow/python:list_ops",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/autograph/pyct",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
diff --git a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py b/tensorflow/python/autograph/utils/__init__.py
index 2c99f4077e..c781958481 100644
--- a/tensorflow/contrib/tensorboard/plugins/trace/__init__.py
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Public API for the Trace plugin."""
+"""Utility module that contains APIs usable in the generated code."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=wildcard-import
-from tensorflow.contrib.tensorboard.plugins.trace.trace import *
-from tensorflow.contrib.tensorboard.plugins.trace.trace_info_pb2 import *
-# pylint: enable=wildcard-import
+from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
+from tensorflow.python.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
+from tensorflow.python.autograph.utils.py_func import wrap_py_func
+from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
+from tensorflow.python.autograph.utils.testing import fake_tf
+from tensorflow.python.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/autograph/utils/context_managers.py b/tensorflow/python/autograph/utils/context_managers.py
index 3d150a9581..3d150a9581 100644
--- a/tensorflow/contrib/autograph/utils/context_managers.py
+++ b/tensorflow/python/autograph/utils/context_managers.py
diff --git a/tensorflow/contrib/autograph/utils/context_managers_test.py b/tensorflow/python/autograph/utils/context_managers_test.py
index 42e27724b9..7f0a15b076 100644
--- a/tensorflow/contrib/autograph/utils/context_managers_test.py
+++ b/tensorflow/python/autograph/utils/context_managers_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import context_managers
+from tensorflow.python.autograph.utils import context_managers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import tensor_array_ops
diff --git a/tensorflow/contrib/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py
index 1b06caf0bd..1b06caf0bd 100644
--- a/tensorflow/contrib/autograph/utils/misc.py
+++ b/tensorflow/python/autograph/utils/misc.py
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/python/autograph/utils/misc_test.py
index 968ea03df6..8d2b0d6e13 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/python/autograph/utils/misc_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.misc import alias_tensors
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops.variables import Variable
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
index 70eef5676f..107c8f7a68 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -18,20 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.type_check import is_tensor
+from tensorflow.python.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
-def dynamic_is(left, right):
- # TODO(alexbw) if we're sure we should leave 'is' in place,
- # then change the semantics in converters/logical_expressions.py
- return left is right
-
-
-def dynamic_is_not(left, right):
- return left is not right
-
-
def run_cond(condition, true_fn, false_fn):
"""Type-dependent functional conditional.
diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
index f72f8e94a0..2a77c895ce 100644
--- a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -18,9 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.autograph.utils import multiple_dispatch
+from tensorflow.python.autograph.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.platform import test
@@ -28,33 +26,6 @@ from tensorflow.python.platform import test
class MultipleDispatchTest(test.TestCase):
- def test_dynamic_is_python(self):
- a = np.eye(3)
- also_a = a
- not_actually_a = np.eye(3)
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
- def test_dynamic_is_tf(self):
- with Session().as_default():
- a = constant([2.0])
- also_a = a
- not_actually_a = constant([2.0])
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
def test_run_cond_python(self):
true_fn = lambda: (2,)
false_fn = lambda: (3,)
diff --git a/tensorflow/contrib/autograph/utils/py_func.py b/tensorflow/python/autograph/utils/py_func.py
index 11ebfb2e49..11ebfb2e49 100644
--- a/tensorflow/contrib/autograph/utils/py_func.py
+++ b/tensorflow/python/autograph/utils/py_func.py
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/python/autograph/utils/py_func_test.py
index f60b57bcce..1c220d9492 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/python/autograph/utils/py_func_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py
index 2556f41289..2556f41289 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list.py
+++ b/tensorflow/python/autograph/utils/tensor_list.py
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
index faaf7b7877..697c166eb1 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import tensor_list as tl
+from tensorflow.python.autograph.utils import tensor_list as tl
from tensorflow.python.client.session import Session
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/python/autograph/utils/tensors.py
index fa5db81a71..fa5db81a71 100644
--- a/tensorflow/contrib/autograph/utils/tensors.py
+++ b/tensorflow/python/autograph/utils/tensors.py
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/python/autograph/utils/tensors_test.py
index e855e0b6cb..1e7cfec9e1 100644
--- a/tensorflow/contrib/autograph/utils/tensors_test.py
+++ b/tensorflow/python/autograph/utils/tensors_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.autograph.utils import tensors
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import list_ops
diff --git a/tensorflow/contrib/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py
index cb4785d0dc..cb4785d0dc 100644
--- a/tensorflow/contrib/autograph/utils/testing.py
+++ b/tensorflow/python/autograph/utils/testing.py
diff --git a/tensorflow/contrib/autograph/utils/type_check.py b/tensorflow/python/autograph/utils/type_check.py
index 8748abc47b..8748abc47b 100644
--- a/tensorflow/contrib/autograph/utils/type_check.py
+++ b/tensorflow/python/autograph/utils/type_check.py
diff --git a/tensorflow/contrib/autograph/utils/type_check_test.py b/tensorflow/python/autograph/utils/type_check_test.py
index 3b67b7194c..b3d1304e16 100644
--- a/tensorflow/contrib/autograph/utils/type_check_test.py
+++ b/tensorflow/python/autograph/utils/type_check_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
import numpy
-from tensorflow.contrib.autograph.utils import type_check
+from tensorflow.python.autograph.utils import type_check
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index ae0ad27f15..c963cfd334 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -178,16 +178,30 @@ def register_session_run_conversion_functions(
feed_function_for_partial_run: A callable for specifying tensor values to
feed when setting up a partial run, which takes a `tensor_type` type
object as input, and returns a list of Tensors.
+
+ Raises:
+ ValueError: If `tensor_type` has already been registered.
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError('%s has already been registered so ignore it.',
+ raise ValueError('%s has already been registered so ignore it.' %
tensor_type)
- return
+
_REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
feed_function_for_partial_run))
+def _is_attrs_instance(obj):
+ """Returns True if the given obj is an instance of attrs-decorated class."""
+ return getattr(obj.__class__, '__attrs_attrs__', None) is not None
+
+
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, '__attrs_attrs__')
+ return [getattr(obj, a.name) for a in attrs]
+
+
class _FetchMapper(object):
"""Definition of the interface provided by fetch mappers.
@@ -247,6 +261,8 @@ class _FetchMapper(object):
return _ListFetchMapper(fetch)
elif isinstance(fetch, collections.Mapping):
return _DictFetchMapper(fetch)
+ elif _is_attrs_instance(fetch):
+ return _AttrsFetchMapper(fetch)
else:
# Look for a handler in the registered expansions.
for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
@@ -398,6 +414,32 @@ class _DictFetchMapper(_FetchMapper):
return results
+class _AttrsFetchMapper(_FetchMapper):
+ """Fetch mapper for attrs decorated classes."""
+
+ def __init__(self, fetches):
+ """Creates a _AttrsFetchMapper.
+
+ Args:
+ fetches: An instance of an attrs decorated class.
+ """
+ values = _get_attrs_values(fetches)
+ self._fetch_type = type(fetches)
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in values
+ ]
+ self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
+
+ def unique_fetches(self):
+ return self._unique_fetches
+
+ def build_results(self, values):
+ results = []
+ for m, vi in zip(self._mappers, self._value_indices):
+ results.append(m.build_results([values[j] for j in vi]))
+ return self._fetch_type(*results)
+
+
class _FetchHandler(object):
"""Handler for structured fetches.
diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc
new file mode 100644
index 0000000000..4d361612b7
--- /dev/null
+++ b/tensorflow/python/client/session_ref.cc
@@ -0,0 +1,525 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/client/session_ref.h"
+
+#include <stdlib.h>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/named_tensor.pb.h"
+#include "tensorflow/core/protobuf/replay_log.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Scope helper to track active calls and manage session lifetime.
+// SessionRef blocks closing until all active calls complete or are cancelled.
+struct RunCounter {
+ std::shared_ptr<Session> session;
+ uint64* value;
+ mutex* m;
+ condition_variable* cv;
+
+ explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
+ condition_variable* cv)
+ : session(std::move(s)), value(v), m(m), cv(cv) {
+ mutex_lock l(*m);
+ ++*value;
+ }
+
+ ~RunCounter() {
+ mutex_lock l(*m);
+ if (--*value == 0) {
+ cv->notify_all();
+ }
+ }
+};
+
+std::string SessionToHandle(Session* session) {
+ return strings::Printf("%llu", reinterpret_cast<uint64>(session));
+}
+
+// The Session interface has many methods of the form:
+//
+// X(a, b);
+// X(RunOptions, a, b);
+//
+// Not all sessions support the second case (with an empty RunOptions()).
+// We use this variable as a sentinel to dispatch to the correct call.
+RunOptions* kEmptyRunOptions() {
+ static RunOptions* options = new RunOptions();
+ return options;
+}
+
+} // namespace
+
+// Run the given session operation, recording start and end timestamps.
+// If the operation returns a bad status, return after flushing the current
+// log request. This should be run _after_ all request information has been
+// added to the current op.
+#define RUN_WITH_TIMESTAMP(OpName, ...) \
+ op.set_start_time_us(Env::Default()->NowMicros()); \
+ Status status = session->OpName(__VA_ARGS__); \
+ op.set_end_time_us(Env::Default()->NowMicros()); \
+ if (!status.ok()) { \
+ Flush(op).IgnoreError(); \
+ return status; \
+ }
+
+// Records requests (and optionally responses) performed against a session.
+// The resulting replay log can be used with the `tf_replay` tool to replicate
+// the operations against a simulated environment, without requiring the
+// original code or cluster setup.
+//
+// Session logging by setting the TF_REPLAY_LOG_FILE environment variable.
+class SessionLogger {
+ public:
+ SessionLogger() {
+ std::string log_name = getenv("TF_REPLAY_LOG_FILE");
+ LOG(INFO) << "Constructing new session logger for " << log_name;
+ TF_CHECK_OK(
+ Env::Default()->RecursivelyCreateDir(string(io::Dirname(log_name))));
+ Env::Default()->DeleteFile(log_name).IgnoreError();
+
+ TF_CHECK_OK(Env::Default()->NewWritableFile(log_name, &log_file_));
+ log_writer_ = absl::make_unique<io::RecordWriter>(log_file_.get());
+ }
+
+ ~SessionLogger() {
+ log_writer_->Close().IgnoreError();
+ log_writer_.release();
+ log_file_->Close().IgnoreError();
+ }
+
+ Status RecordNewSession(Session* session) {
+ LOG(INFO) << "New session discovered. Capturing devices...";
+ ReplayOp op;
+ NewReplaySession* req = op.mutable_new_replay_session();
+
+ std::vector<DeviceAttributes> devices;
+ Status status = session->ListDevices(&devices);
+ if (status.ok()) {
+ LOG(INFO) << "Found: " << devices.size() << " devices.";
+ for (const DeviceAttributes& dev : devices) {
+ *req->mutable_devices()->add_local_device() = dev;
+ }
+ } else {
+ LOG(WARNING) << "Failed to list devices on session. Continuing.";
+ }
+
+ req->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordRun(Session* session,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names,
+ target_node_names, outputs, nullptr);
+ }
+
+ Status RecordRun(Session* session, const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs, RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = run_options;
+
+ for (const auto& it : inputs) {
+ NamedTensorProto* feed = req->add_feed();
+ feed->set_name(it.first);
+ it.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ // Build an index from fetch tensor name to first index in
+ // output_tensor_names.
+ std::unordered_map<string, int> output_name_to_offset;
+ for (int i = 0; i < output_tensor_names.size(); ++i) {
+ const string& name = output_tensor_names[i];
+ if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
+ req->add_fetch(name);
+ }
+ }
+ for (const string& target : target_node_names) {
+ req->add_target(target);
+ }
+
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+ } else {
+ RUN_WITH_TIMESTAMP(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+ }
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_tensor_names[i]);
+ }
+
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordCreate(Session* session, const GraphDef& graph) {
+ return RecordCreate(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CreateRequest)
+ Status RecordCreate(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ CreateSessionRequest* req = op.mutable_create_session();
+ *req->mutable_graph_def() = graph;
+
+ CreateSessionResponse* resp = op.mutable_create_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Create, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Create, run_options, graph);
+ }
+ resp->set_session_handle(SessionToHandle(session));
+ return Flush(op);
+ }
+
+ Status RecordExtend(Session* session, const GraphDef& graph) {
+ return RecordExtend(session, *kEmptyRunOptions(), graph);
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in ExtendRequest)
+ Status RecordExtend(Session* session, const RunOptions& run_options,
+ const GraphDef& graph) {
+ ReplayOp op;
+ ExtendSessionRequest* req = op.mutable_extend_session();
+ op.mutable_extend_session_response();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_graph_def() = graph;
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Extend, graph);
+ } else {
+ RUN_WITH_TIMESTAMP(Extend, run_options, graph);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordClose(Session* session) {
+ return RecordClose(session, *kEmptyRunOptions());
+ }
+
+ // N.B. RunOptions is not stored (it has no entry in CloseRequest)
+ Status RecordClose(Session* session, const RunOptions& run_options) {
+ ReplayOp op;
+ CloseSessionRequest* req = op.mutable_close_session();
+ req->set_session_handle(SessionToHandle(session));
+ op.mutable_close_session_response();
+ if (&run_options == kEmptyRunOptions()) {
+ RUN_WITH_TIMESTAMP(Close);
+ } else {
+ RUN_WITH_TIMESTAMP(Close, run_options);
+ }
+ return Flush(op);
+ }
+
+ Status RecordListDevices(Session* session,
+ std::vector<DeviceAttributes>* response) {
+ ReplayOp op;
+ ListDevicesRequest* req = op.mutable_list_devices();
+ ListDevicesResponse* resp = op.mutable_list_devices_response();
+ req->set_session_handle(SessionToHandle(session));
+ RUN_WITH_TIMESTAMP(ListDevices, response);
+
+ // TODO(power) -- local vs remote device distinction is lost here!
+ *resp->mutable_local_device() = {response->begin(), response->end()};
+ return Flush(op);
+ }
+
+ Status RecordPRunSetup(Session* session,
+ const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ ReplayOp op;
+ PartialRunSetupRequest* req = op.mutable_partial_run_setup();
+ req->set_session_handle(SessionToHandle(session));
+ for (auto& input : input_names) {
+ req->add_feed(input);
+ }
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+ for (auto& target : target_nodes) {
+ req->add_target(target);
+ }
+ RUN_WITH_TIMESTAMP(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+ op.mutable_partial_run_setup_response()->set_partial_run_handle(*handle);
+ return Flush(op);
+ }
+
+ Status RecordPRun(Session* session, const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ ReplayOp op;
+ RunStepRequest* req = op.mutable_run_step();
+ RunStepResponse* resp = op.mutable_run_step_response();
+ req->set_session_handle(SessionToHandle(session));
+
+ // Mark this step as a partial run for replay.
+ req->set_partial_run_handle(handle);
+ for (auto& input : inputs) {
+ auto* feed = req->add_feed();
+ feed->set_name(input.first);
+ input.second.AsProtoField(feed->mutable_tensor());
+ }
+
+ for (auto& output : output_names) {
+ req->add_fetch(output);
+ }
+
+ RUN_WITH_TIMESTAMP(PRun, handle, inputs, output_names, outputs);
+
+ for (size_t i = 0; i < outputs->size(); ++i) {
+ const Tensor& tensor = (*outputs)[i];
+ NamedTensorProto* tproto = resp->add_tensor();
+ tensor.AsProtoField(tproto->mutable_tensor());
+ tproto->set_name(output_names[i]);
+ }
+
+ return Flush(op);
+ }
+
+ Status RecordMakeCallable(Session* session,
+ const CallableOptions& callable_options,
+ Session::CallableHandle* handle) {
+ ReplayOp op;
+ MakeCallableRequest* req = op.mutable_make_callable();
+ req->set_session_handle(SessionToHandle(session));
+ *req->mutable_options() = callable_options;
+
+ RUN_WITH_TIMESTAMP(MakeCallable, callable_options, handle);
+
+ MakeCallableResponse* resp = op.mutable_make_callable_response();
+ resp->set_handle(*handle);
+
+ return Flush(op);
+ }
+
+ Status RecordRunCallable(Session* session, Session::CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ ReplayOp op;
+ RunCallableRequest* req = op.mutable_run_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ for (auto& tensor : feed_tensors) {
+ tensor.AsProtoField(req->add_feed());
+ }
+ RUN_WITH_TIMESTAMP(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+
+ RunCallableResponse* resp = op.mutable_run_callable_response();
+ if (run_metadata) {
+ *resp->mutable_metadata() = *run_metadata;
+ }
+ for (const Tensor& tensor : *fetch_tensors) {
+ tensor.AsProtoTensorContent(resp->add_fetch());
+ }
+ return Flush(op);
+ }
+
+ Status RecordReleaseCallable(Session* session,
+ Session::CallableHandle handle) {
+ ReplayOp op;
+ ReleaseCallableRequest* req = op.mutable_release_callable();
+ req->set_session_handle(SessionToHandle(session));
+ req->set_handle(handle);
+ RUN_WITH_TIMESTAMP(ReleaseCallable, handle);
+ return Flush(op);
+ }
+
+ private:
+ Status Flush(const ReplayOp& op) {
+ mutex_lock l(log_mutex_);
+
+ string buf;
+ op.SerializeToString(&buf);
+ TF_RETURN_IF_ERROR(log_writer_->WriteRecord(buf));
+
+ // TODO(b/116624106): Not all file-systems respect calls to `Sync()`
+ return log_file_->Sync();
+ }
+
+ std::unique_ptr<WritableFile> log_file_;
+ std::unique_ptr<io::RecordWriter> log_writer_;
+ mutex log_mutex_;
+};
+
+static SessionLogger* global_session_logger() {
+ static SessionLogger* logger = new SessionLogger();
+ return logger;
+}
+
+SessionRef::SessionRef(Session* session) : session_(session) {
+ if (getenv("TF_REPLAY_LOG_FILE") != nullptr) {
+ logger_ = global_session_logger();
+ logger_->RecordNewSession(this->session_.get()).IgnoreError();
+ } else {
+ logger_ = nullptr;
+ }
+}
+
+SessionRef::~SessionRef() = default;
+
+Status SessionRef::CheckNotClosed() {
+ mutex_lock l(run_lock_);
+ if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
+ return ::tensorflow::Status::OK();
+}
+
+// If logging is active, log the start and end time of the operation along with
+// the request and response.
+#define LOG_AND_RUN_OPERATION(OpName, ...) \
+ TF_RETURN_IF_ERROR(CheckNotClosed()); \
+ RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); \
+ if (!logger_) { \
+ return rc.session->OpName(__VA_ARGS__); \
+ } \
+ return logger_->Record##OpName(rc.session.get(), __VA_ARGS__);
+
+Status SessionRef::Run(const RunOptions& run_options,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names,
+ target_node_names, outputs, run_metadata);
+}
+
+Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_tensor_names,
+ const std::vector<string>& target_node_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names,
+ outputs);
+}
+
+Status SessionRef::Create(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, graph);
+}
+
+Status SessionRef::Create(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Create, run_options, graph);
+}
+
+Status SessionRef::Extend(const RunOptions& run_options,
+ const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, run_options, graph);
+}
+
+Status SessionRef::Extend(const GraphDef& graph) {
+ LOG_AND_RUN_OPERATION(Extend, graph);
+}
+
+Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
+ LOG_AND_RUN_OPERATION(ListDevices, response);
+}
+
+Status SessionRef::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes,
+ handle);
+}
+
+Status SessionRef::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs);
+}
+
+Status SessionRef::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle);
+}
+
+Status SessionRef::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors,
+ run_metadata);
+}
+
+Status SessionRef::ReleaseCallable(CallableHandle handle) {
+ LOG_AND_RUN_OPERATION(ReleaseCallable, handle);
+}
+
+Status SessionRef::Close(const RunOptions& run_options) {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get(), run_options);
+ } else {
+ status = session_->Close(run_options);
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+Status SessionRef::Close() {
+ TF_RETURN_IF_ERROR(CheckNotClosed());
+ mutex_lock l(run_lock_);
+ Status status;
+ if (logger_) {
+ status = logger_->RecordClose(session_.get());
+ } else {
+ status = session_->Close();
+ }
+ session_.reset();
+ while (run_count_ > 0) {
+ run_finished_.wait(l);
+ }
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/python/client/session_ref.h
index 9459e7edbe..b0fb12b189 100644
--- a/tensorflow/core/common_runtime/session_ref.h
+++ b/tensorflow/python/client/session_ref.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#ifndef TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
+#define TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
#include <memory>
@@ -22,6 +22,8 @@ limitations under the License.
namespace tensorflow {
+class SessionLogger;
+
// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
//
// SessionRef blocks the return of Close() until all pending operations have
@@ -29,8 +31,8 @@ namespace tensorflow {
// subsequent operations on the SessionRef object will return errors::Cancelled.
class SessionRef : public Session {
public:
- SessionRef(Session* session) : session_(session) {}
- virtual ~SessionRef() {}
+ explicit SessionRef(Session* session);
+ ~SessionRef() override;
Status Create(const GraphDef& graph) override;
Status Extend(const GraphDef& graph) override;
@@ -78,9 +80,12 @@ class SessionRef : public Session {
uint64 run_count_ GUARDED_BY(run_lock_) = {0};
std::shared_ptr<Session> session_;
+ // Borrowed reference to global session logger.
+ SessionLogger* logger_;
+
Status CheckNotClosed();
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
+#endif // TENSORFLOW_PYTHON_CLIENT_SESSION_REF_H_
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4afc6399d5..f576435136 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -61,6 +61,12 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
+
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -300,6 +306,82 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(None, res[2])
self.assertEqual(44.0, res[1])
+ def testFetchAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ val1 = np.array([1.2, 3.4, 5.6])
+ val2 = np.array([[1, 2], [4, 3]])
+ val3 = np.array([10, 20, 30])
+
+ t1 = constant_op.constant(val1)
+ t2 = constant_op.constant(val2)
+
+ sample = SampleAttr(t1, t2)
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val1, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ result = sess.run(sample, feed_dict={sample.field1: val3})
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val3, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ def testFetchNestedAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field0 = attr.ib()
+ field1 = attr.ib()
+
+ v1 = 10
+ v2 = 20
+ v3 = np.float32(1.2)
+ v4 = np.float32(3.4)
+ v5 = np.float64(100.001)
+ v6 = np.float64(-23.451)
+ arr1 = np.array([1.2, 6.7, 3.4])
+ arr2 = np.array([7, 11, 3])
+ sample = SampleAttr(
+ SampleAttr(
+ SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
+ SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
+ {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
+ 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
+
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertIsInstance(result.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field1, SampleAttr)
+ self.assertIsInstance(result.field0.field1.field0, np.ndarray)
+ self.assertAllEqual(arr1, result.field0.field1.field0)
+ self.assertIsInstance(result.field0.field1.field1, np.ndarray)
+ self.assertAllEqual(arr2, result.field0.field1.field1)
+ self.assertIsInstance(result.field1, dict)
+ self.assertIn('A', result.field1)
+ self.assertIn('B', result.field1)
+ self.assertIsInstance(result.field1['A'], SampleAttr)
+ self.assertAllEqual(
+ [v3, v4],
+ [result.field1['A'].field0, result.field1['A'].field1])
+ self.assertIsInstance(result.field1['B'], list)
+ self.assertEqual(1, len(result.field1['B']))
+ self.assertIsInstance(result.field1['B'][0], SampleAttr)
+ self.assertAllEqual(
+ [v5, v6],
+ [result.field1['B'][0].field0, result.field1['B'][0].field1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 39a2922ac0..ef7527d887 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -463,7 +463,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
}
// Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
@@ -782,7 +782,7 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
-%unignore ResourceHandleShapeAndType;
+%unignore HandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index bcd4af2912..dc0c10bab7 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
-#include "tensorflow/core/common_runtime/session_ref.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
@@ -31,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
+#include "tensorflow/python/client/session_ref.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index 1e96ac5ed4..c3f38294b5 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -588,7 +588,8 @@ class Timeline(object):
alloc_tensor_set = set()
alloc_maxes[allocator] = AllocationMaximum(
timestamp=0, num_bytes=0, tensors=set())
- for time, num_bytes, name in alloc_list:
+ for time, num_bytes, name in sorted(
+ alloc_list, key=lambda allocation: allocation[0]):
total_bytes += num_bytes
if num_bytes < 0:
alloc_tensor_set.discard(name)
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index c046e9cfd4..032bbf7c4e 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -161,10 +161,8 @@ class TimelineTest(test.TestCase):
cpu_max = maximums[
'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums[cpuname]
# At least num1 + num2, both float32s (4 bytes each)
- self.assertGreater(cpu_max.num_bytes, 8)
+ self.assertGreaterEqual(cpu_max.num_bytes, 8)
self.assertGreater(cpu_max.timestamp, 0)
- self.assertTrue('num1' in cpu_max.tensors or 'num1/read' in cpu_max.tensors)
- self.assertTrue('num2' in cpu_max.tensors or 'num2/read' in cpu_max.tensors)
def testManyCPUs(self):
run_options = config_pb2.RunOptions(
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 60ebae19ab..74fe1fe35c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 11)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 25)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 3e08c1587e..138141f4fc 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -12,6 +12,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 631b87a718..7a6f03d4d3 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -394,6 +394,7 @@ cuda_py_test(
size = "small",
srcs = ["optional_ops_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:optional_ops",
@@ -407,3 +408,53 @@ cuda_py_test(
"//tensorflow/python:tensor_shape",
],
)
+
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "small",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
+tf_py_test(
+ name = "window_dataset_op_test",
+ size = "small",
+ srcs = ["window_dataset_op_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
+ name = "inputs_test",
+ size = "small",
+ srcs = ["inputs_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
new file mode 100644
index 0000000000..4c9279dd95
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -0,0 +1,148 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class InputsTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index a35cee594a..e7e51df65e 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -134,7 +134,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
result.append([value] * value)
return result * count
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected_element in self._interleave(
repeat(input_values, count), cycle_length, block_length):
self.assertEqual(expected_element, sess.run(get_next))
@@ -169,7 +169,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
num_parallel_calls)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for value in input_values:
if np.isnan(value):
with self.assertRaises(errors.InvalidArgumentError):
@@ -195,7 +195,7 @@ class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 7685d8dbdc..ae04995436 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -397,6 +397,28 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
# Randomness is repeatable given same seed
self.assertAllClose(random_values, random_values_2)
+ def testStatefulMapKeepsStateAcrossIterators(self):
+ iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10)
+ .map(lambda _: random_ops.random_uniform((), seed=11))
+ .repeat(1000)
+ .batch(10)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ random_values = sess.run(get_next)
+
+ # Assert that one of the next 99 batches yielded by the iterator is
+ # different from the first.
+ i = 0
+ while i < 99:
+ if np.any(random_values != sess.run(get_next)):
+ break
+ i += 1
+ self.assertLess(i, 99)
+
def testMapDict(self):
iterator = (dataset_ops.Dataset.range(10)
.map(lambda x: {"foo": x * 2, "bar": x ** 2})
@@ -731,7 +753,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tids = sess.run(get_next)
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
new file mode 100644
index 0000000000..056664b83b
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -0,0 +1,190 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiDeviceIterator tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class MultiDeviceIteratorTest(test.TestCase):
+
+ def testNoGetNext(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+
+ def testBasic(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testOneOnSameDevice(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:0", "/cpu:1"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testRepeatDevices(self):
+ with ops.device("/cpu:0"):
+ dataset = dataset_ops.Dataset.range(20)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2", "/cpu:1", "/cpu:2"])
+ elements = multi_device_iterator.get_next()
+ elem_on_1, elem_on_2, elem_on_3, elem_on_4 = elements
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 20, 4):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(i + 2, sess.run(elem_on_3))
+ self.assertEqual(i + 3, sess.run(elem_on_4))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+ sess.run(elem_on_3)
+ sess.run(elem_on_4)
+
+ def testNotFullyDivisible(self):
+ dataset = dataset_ops.Dataset.range(9)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 8, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ self.assertEqual(8, sess.run(elem_on_1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUneven(self):
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testMultipleInitializations(self):
+ with ops.device("/cpu:0"):
+ epoch = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset1 = dataset_ops.Dataset.from_tensors(epoch).repeat(1000)
+ dataset2 = dataset_ops.Dataset.range(1000)
+ dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+ init_op = multi_device_iterator.initializer
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.test_session(config=config) as sess:
+ for i in range(1000):
+ sess.run(init_op, feed_dict={epoch: i})
+ self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
+
+ def testBasicGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"])
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+ def testUnevenGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ dataset = dataset_ops.Dataset.range(10)
+ multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+ dataset, ["/cpu:1", "/gpu:0"], max_buffer_size=4)
+ elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+
+ config = config_pb2.ConfigProto(device_count={"CPU": 2, "GPU": 1})
+ with self.test_session(config=config) as sess:
+ sess.run(multi_device_iterator.initializer)
+ for i in range(0, 10, 2):
+ self.assertEqual(i, sess.run(elem_on_1))
+ for i in range(0, 10, 2):
+ self.assertEqual(i + 1, sess.run(elem_on_2))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elem_on_1)
+ sess.run(elem_on_2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index c344513e71..706a65fe55 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -17,11 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,14 +35,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase):
+class OptionalTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
@@ -50,15 +49,6 @@ class OptionalTest(test.TestCase):
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
})
- self.assertEqual({
- "a": dtypes.float32,
- "b": (dtypes.string, dtypes.string)
- }, opt.output_types)
- self.assertEqual({"a": [], "b": ([1], [])}, opt.output_shapes)
- self.assertEqual({
- "a": ops.Tensor,
- "b": (ops.Tensor, ops.Tensor)
- }, opt.output_classes)
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual({
"a": 37.0,
@@ -76,46 +66,29 @@ class OptionalTest(test.TestCase):
values=np.array([-1., 1.], dtype=np.float32),
dense_shape=np.array([2, 2]))
opt = optional_ops.Optional.from_value((st_0, st_1))
- self.assertEqual((dtypes.int64, dtypes.float32), opt.output_types)
- self.assertEqual(([1], [2, 2]), opt.output_shapes)
- self.assertEqual((sparse_tensor.SparseTensor, sparse_tensor.SparseTensor),
- opt.output_classes)
+ self.assertTrue(self.evaluate(opt.has_value()))
+ val_0, val_1 = opt.get_value()
+ for expected, actual in [(st_0, val_0), (st_1, val_1)]:
+ self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
+ self.assertAllEqual(expected.values, self.evaluate(actual.values))
+ self.assertAllEqual(expected.dense_shape,
+ self.evaluate(actual.dense_shape))
@test_util.run_in_graph_and_eager_modes
def testFromNone(self):
- opt = optional_ops.Optional.none_from_structure(tensor_shape.scalar(),
- dtypes.float32, ops.Tensor)
- self.assertEqual(dtypes.float32, opt.output_types)
- self.assertEqual([], opt.output_shapes)
- self.assertEqual(ops.Tensor, opt.output_classes)
+ value_structure = structure.TensorStructure(dtypes.float32, [])
+ opt = optional_ops.Optional.none_from_structure(value_structure)
+ self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.float32, [1])))
+ self.assertFalse(
+ opt.value_structure.is_compatible_with(
+ structure.TensorStructure(dtypes.int32, [])))
self.assertFalse(self.evaluate(opt.has_value()))
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
- def testStructureMismatchError(self):
- tuple_output_shapes = (tensor_shape.scalar(), tensor_shape.scalar())
- tuple_output_types = (dtypes.float32, dtypes.float32)
- tuple_output_classes = (ops.Tensor, ops.Tensor)
-
- dict_output_shapes = {
- "a": tensor_shape.scalar(),
- "b": tensor_shape.scalar()
- }
- dict_output_types = {"a": dtypes.float32, "b": dtypes.float32}
- dict_output_classes = {"a": ops.Tensor, "b": ops.Tensor}
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, tuple_output_types, dict_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- tuple_output_shapes, dict_output_types, tuple_output_classes)
-
- with self.assertRaises(TypeError):
- optional_ops.Optional.none_from_structure(
- dict_output_shapes, tuple_output_types, tuple_output_classes)
-
@test_util.run_in_graph_and_eager_modes
def testCopyToGPU(self):
if not test_util.is_gpu_available():
@@ -126,17 +99,15 @@ class OptionalTest(test.TestCase):
(constant_op.constant(37.0), constant_op.constant("Foo"),
constant_op.constant(42)))
optional_none = optional_ops.Optional.none_from_structure(
- tensor_shape.scalar(), dtypes.float32, ops.Tensor)
+ structure.TensorStructure(dtypes.float32, []))
with ops.device("/gpu:0"):
gpu_optional_with_value = optional_ops._OptionalImpl(
array_ops.identity(optional_with_value._variant_tensor),
- optional_with_value.output_shapes, optional_with_value.output_types,
- optional_with_value.output_classes)
+ optional_with_value.value_structure)
gpu_optional_none = optional_ops._OptionalImpl(
array_ops.identity(optional_none._variant_tensor),
- optional_none.output_shapes, optional_none.output_types,
- optional_none.output_classes)
+ optional_none.value_structure)
gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
gpu_optional_with_value_values = gpu_optional_with_value.get_value()
@@ -148,14 +119,101 @@ class OptionalTest(test.TestCase):
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
- def testIteratorGetNextAsOptional(self):
- ds = dataset_ops.Dataset.range(3)
+ def _assertElementValueEqual(self, expected, actual):
+ if isinstance(expected, dict):
+ self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
+ for k in expected.keys():
+ self._assertElementValueEqual(expected[k], actual[k])
+ elif isinstance(expected, sparse_tensor.SparseTensorValue):
+ self.assertAllEqual(expected.indices, actual.indices)
+ self.assertAllEqual(expected.values, actual.values)
+ self.assertAllEqual(expected.dense_shape, actual.dense_shape)
+ else:
+ self.assertAllEqual(expected, actual)
+
+ # pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Tensor", lambda: constant_op.constant(37.0),
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
+ dense_shape=[1]),
+ structure.SparseTensorStructure(dtypes.int32, [1])),
+ ("Nest", lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.TensorStructure(dtypes.string, [1]),
+ structure.TensorStructure(dtypes.string, []))})),
+ ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+ optional_ops.OptionalStructure(
+ structure.TensorStructure(dtypes.float32, []))),
+ )
+ def testOptionalStructure(self, tf_value_fn, expected_value_structure):
+ tf_value = tf_value_fn()
+ opt = optional_ops.Optional.from_value(tf_value)
+
+ self.assertTrue(
+ expected_value_structure.is_compatible_with(opt.value_structure))
+ self.assertTrue(
+ opt.value_structure.is_compatible_with(expected_value_structure))
+
+ opt_structure = structure.Structure.from_value(opt)
+ self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
+ self.assertTrue(opt_structure.is_compatible_with(opt_structure))
+ self.assertTrue(opt_structure._value_structure.is_compatible_with(
+ expected_value_structure))
+ self.assertEqual([dtypes.variant], opt_structure._flat_types)
+ self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)
+
+ # All OptionalStructure objects are not compatible with a non-optional
+ # value.
+ non_optional_structure = structure.Structure.from_value(
+ constant_op.constant(42.0))
+ self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))
+
+ # Assert that the optional survives a round-trip via _from_tensor_list()
+ # and _to_tensor_list().
+ round_trip_opt = opt_structure._from_tensor_list(
+ opt_structure._to_tensor_list(opt))
+ if isinstance(tf_value, optional_ops.Optional):
+ self.assertEqual(
+ self.evaluate(tf_value.get_value()),
+ self.evaluate(round_trip_opt.get_value().get_value()))
+ else:
+ self.assertEqual(
+ self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value()))
+
+ @parameterized.named_parameters(
+ ("Tensor", np.array([1, 2, 3], dtype=np.int32),
+ lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
+ ("SparseTensor", sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
+ lambda: sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
+ False),
+ ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
+ "b": sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]],
+ values=np.array([-1., 1.], dtype=np.float32),
+ dense_shape=[2, 2])},
+ lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
+ "b": sparse_tensor.SparseTensor(
+ indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
+ dense_shape=[2, 2])}, False),
+ )
+ def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
+ if not works_on_gpu and test.is_gpu_available():
+ self.skipTest("Test case not yet supported on GPU.")
+ ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
iterator = ds.make_initializable_iterator()
next_elem = iterator_ops.get_next_as_optional(iterator)
- self.assertTrue(isinstance(next_elem, optional_ops.Optional))
- self.assertEqual(ds.output_types, next_elem.output_types)
- self.assertEqual(ds.output_shapes, next_elem.output_shapes)
- self.assertEqual(ds.output_classes, next_elem.output_classes)
+ self.assertIsInstance(next_elem, optional_ops.Optional)
+ self.assertTrue(
+ next_elem.value_structure.is_compatible_with(
+ structure.Structure.from_value(tf_value_fn())))
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
with self.cached_session() as sess:
@@ -169,10 +227,10 @@ class OptionalTest(test.TestCase):
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
sess.run(iterator.initializer)
- for i in range(3):
+ for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
+ self._assertElementValueEqual(np_value, elem_value)
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
new file mode 100644
index 0000000000..fd4348426d
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -0,0 +1,295 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
+ ("12", 20, 14, 7, 1, False),
+ ("13", 20, 17, 9, 1, False),
+ ("14", 20, 14, 14, 1, False),
+ ("15", 20, 10, 14, 1, False),
+ ("16", 20, 14, 19, 1, False),
+ ("17", 20, 4, 1, 2, False),
+ ("18", 20, 2, 1, 6, False),
+ ("19", 20, 4, 7, 2, False),
+ ("20", 20, 2, 7, 6, False),
+ ("21", 1, 10, 4, 1, False),
+ ("22", 0, 10, 4, 1, False),
+ )
+ def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+ """Tests a dataset that slides a window its input elements."""
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+ drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ def _flat_map_fn(x, y, z):
+ return dataset_ops.Dataset.zip((x.batch(batch_size=size_t),
+ y.batch(batch_size=size_t),
+ z.batch(batch_size=size_t)))
+
+ iterator = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).window(
+ size=size_t,
+ shift=shift_t,
+ stride=stride_t,
+ drop_remainder=drop_remainder_t).flat_map(
+ _flat_map_fn).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride,
+ drop_remainder_t: drop_remainder
+ })
+ num_full_batches = max(
+ 0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
+ for i in range(num_full_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(size):
+ self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
+ result_component[j])
+ if not drop_remainder:
+ num_partial_batches = (count * 7) // shift + (
+ (count * 7) % shift > 0) - num_full_batches
+ for i in range(num_partial_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ remaining = (count * 7) - ((num_full_batches + i) * shift)
+ num_elements = remaining // stride + ((remaining % stride) > 0)
+ for j in range(num_elements):
+ self.assertAllEqual(
+ component[((num_full_batches + i) * shift + j * stride) % 7]
+ **2, result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
+ )
+ def testWindowDatasetInvalid(self, count, size, shift, stride):
+ count_t = array_ops.placeholder(dtypes.int64, shape=[])
+ size_t = array_ops.placeholder(dtypes.int64, shape=[])
+ shift_t = array_ops.placeholder(dtypes.int64, shape=[])
+ stride_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
+ count_t).window(
+ size=size_t, shift=shift_t,
+ stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t)
+ ).make_initializable_iterator()
+ init_op = iterator.initializer
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(
+ init_op,
+ feed_dict={
+ count_t: count,
+ size_t: size,
+ shift_t: shift,
+ stride_t: stride
+ })
+
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def testWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowSparseWithDifferentDenseShapes(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=array_ops.expand_dims(
+ math_ops.range(i, dtype=dtypes.int64), 1),
+ values=array_ops.fill([math_ops.to_int32(i)], i),
+ dense_shape=[i])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ actual = sess.run(get_next)
+ expected_indices = []
+ expected_values = []
+ for j in range(5):
+ for k in range(i * 3 + j):
+ expected_indices.append([j, k])
+ expected_values.append(i * 3 + j)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=expected_indices,
+ values=expected_values,
+ dense_shape=[5, i * 3 + 5 - 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testNestedWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=4, shift=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
+ size=3, shift=1, drop_remainder=True).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # Slide: 1st batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ # Slide: 2nd batch.
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
+ dense_shape=[3, 4, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testWindowShapeError(self):
+
+ def generator():
+ yield [1.0, 2.0, 3.0]
+ yield [4.0, 5.0, 6.0]
+ yield [7.0, 8.0, 9.0, 10.0]
+
+ iterator = dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).window(
+ size=3, shift=1).flat_map(
+ lambda x: x.batch(batch_size=3)).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"Cannot batch tensors with different shapes in component 0. "
+ r"First element had shape \[3\] and element 2 had shape \[4\]."):
+ sess.run(next_element)
+
+ def testWindowIgnoreErrors(self):
+ input_values = np.float32([1., np.nan, 2., np.nan, 3.])
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).window(
+ size=2, shift=2, stride=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual(np.float32([1., 2.]), sess.run(get_next))
+ self.assertAllEqual(np.float32([2., 3.]), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 57517afae8..76bf2470b1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -19,6 +19,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:smart_cond",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
@@ -63,6 +64,7 @@ py_library(
"//tensorflow/python/compat",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
],
@@ -77,8 +79,23 @@ py_library(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/util:structure",
+ ],
+)
+
+py_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index c985e00dd1..ac87a451b1 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -80,6 +80,12 @@ class Dataset(object):
"""
raise NotImplementedError("Dataset._as_variant_tensor")
+ @abc.abstractmethod
+ def _inputs(self):
+ """Returns a list of the input datasets of the dataset."""
+
+ raise NotImplementedError("Dataset._inputs")
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -1009,6 +1015,23 @@ class Dataset(object):
def flat_map(self, map_func):
"""Maps `map_func` across this dataset and flattens the result.
+ Use `flat_map` if you want to make sure that the order of your dataset
+ stays the same. For example, to flatten a dataset of batches into a
+ dataset of their elements:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset. '[...]' represents a tensor.
+ a = {[1,2,3,4,5], [6,7,8,9], [10]}
+
+ a.flat_map(lambda x: Dataset.from_tensor_slices(x)) ==
+ {[1,2,3,4,5,6,7,8,9,10]}
+ ```
+
+ `tf.data.Dataset.interleave()` is a generalization of `flat_map`, since
+ `flat_map` produces the same output as
+ `tf.data.Dataset.interleave(cycle_length=1)`
+
Args:
map_func: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1043,7 +1066,7 @@ class Dataset(object):
elements are produced. `cycle_length` controls the number of input elements
that are processed concurrently. If you set `cycle_length` to 1, this
transformation will handle one input element at a time, and will produce
- identical results = to `tf.data.Dataset.flat_map`. In general,
+ identical results to `tf.data.Dataset.flat_map`. In general,
this transformation will apply `map_func` to `cycle_length` input elements,
open iterators on the returned `Dataset` objects, and cycle through them
producing `block_length` consecutive elements from each iterator, and
@@ -1115,7 +1138,7 @@ class Dataset(object):
return FilterDataset(self, predicate)
def apply(self, transformation_func):
- """Apply a transformation function to this dataset.
+ """Applies a transformation function to this dataset.
`apply` enables chaining of custom `Dataset` transformations, which are
represented as functions that take one `Dataset` argument and return a
@@ -1131,7 +1154,7 @@ class Dataset(object):
Args:
transformation_func: A function that takes one `Dataset` argument and
- returns a `Dataset`.
+ returns a `Dataset`.
Returns:
Dataset: The `Dataset` returned by applying `transformation_func` to this
@@ -1140,10 +1163,68 @@ class Dataset(object):
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
raise TypeError("`transformation_func` must return a Dataset.")
+ dataset._input_datasets = [self] # pylint: disable=protected-access
return dataset
+ def window(self, size, shift=None, stride=1, drop_remainder=False):
+ """Combines input elements into a dataset of windows.
+
+ Each window is a dataset itself and contains `size` elements (or
+ possibly fewer if there are not enough input elements to fill the window
+ and `drop_remainder` evaluates to false).
+
+ The `stride` argument determines the stride of the input elements,
+ and the `shift` argument determines the shift of the window.
+
+ For example:
+ - `tf.data.Dataset.range(7).window(2)` produces
+ `{{0, 1}, {2, 3}, {4, 5}, {6}}`
+ - `tf.data.Dataset.range(7).window(3, 2, 1, True)` produces
+ `{{0, 1, 2}, {2, 3, 4}, {4, 5, 6}}`
+ - `tf.data.Dataset.range(7).window(3, 1, 2, True)` produces
+ `{{0, 2, 4}, {1, 3, 5}, {2, 4, 6}}`
+
+ Args:
+ size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements
+ of the input dataset to combine into a window.
+ shift: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ forward shift of the sliding window in each iteration. Defaults to
+ `size`.
+ stride: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ stride of the input elements in the sliding window.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether a window should be dropped in case its size is smaller than
+ `window_size`.
+
+ Returns:
+ Dataset: A `Dataset` of windows, each of which is a nested `Dataset` with
+ the same structure as this dataset, but a finite subsequence of its
+ elements.
+ """
+ if shift is None:
+ shift = size
+ return WindowDataset(self, size, shift, stride, drop_remainder)
+
+
+class DatasetSource(Dataset):
+ """Abstract class representing a dataset with no inputs."""
+
+ def _inputs(self):
+ return []
+
+
+class UnaryDataset(Dataset):
+ """Abstract class representing a dataset with one input."""
+
+ def __init__(self, input_dataset):
+ super(UnaryDataset, self).__init__()
+ self._input_dataset = input_dataset
+
+ def _inputs(self):
+ return [self._input_dataset]
+
-class TensorDataset(Dataset):
+class TensorDataset(DatasetSource):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
@@ -1183,7 +1264,7 @@ class TensorDataset(Dataset):
return self._output_types
-class TensorSliceDataset(Dataset):
+class TensorSliceDataset(DatasetSource):
"""A `Dataset` of slices from a nested structure of tensors."""
def __init__(self, tensors):
@@ -1227,7 +1308,7 @@ class TensorSliceDataset(Dataset):
return self._output_types
-class SparseTensorSliceDataset(Dataset):
+class SparseTensorSliceDataset(DatasetSource):
"""A `Dataset` that splits a rank-N `tf.SparseTensor` into its rows."""
def __init__(self, sparse_tensor):
@@ -1328,6 +1409,9 @@ class _VariantDataset(Dataset):
def _as_variant_tensor(self):
return self._dataset_variant
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return self._structure.output_classes
@@ -1568,7 +1652,7 @@ def flat_structure(dataset):
}
-class _GeneratorDataset(Dataset):
+class _GeneratorDataset(DatasetSource):
"""A `Dataset` that generates elements by invoking a function."""
def __init__(self, init_args, init_func, next_func, finalize_func):
@@ -1669,6 +1753,9 @@ class ZipDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return nest.flatten(self._datasets)
+
@property
def output_classes(self):
return nest.pack_sequence_as(
@@ -1704,6 +1791,7 @@ class ConcatenateDataset(Dataset):
raise TypeError(
"Two datasets to concatenate have different classes %s and %s" %
(input_dataset.output_classes, dataset_to_concatenate.output_classes))
+ self._input_datasets = [input_dataset, dataset_to_concatenate]
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -1713,6 +1801,9 @@ class ConcatenateDataset(Dataset):
**flat_structure(self))
# pylint: enable=protected-access
+ def _inputs(self):
+ return [self._input_dataset, self._dataset_to_concatenate]
+
@property
def output_classes(self):
return self._input_dataset.output_classes
@@ -1731,12 +1822,12 @@ class ConcatenateDataset(Dataset):
return self._input_dataset.output_types
-class RepeatDataset(Dataset):
+class RepeatDataset(UnaryDataset):
"""A `Dataset` that repeats its input several times."""
def __init__(self, input_dataset, count):
"""See `Dataset.repeat()` for details."""
- super(RepeatDataset, self).__init__()
+ super(RepeatDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if count is None:
self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
@@ -1763,7 +1854,7 @@ class RepeatDataset(Dataset):
return self._input_dataset.output_types
-class RangeDataset(Dataset):
+class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
@@ -1811,12 +1902,12 @@ class RangeDataset(Dataset):
return dtypes.int64
-class CacheDataset(Dataset):
+class CacheDataset(UnaryDataset):
"""A `Dataset` that caches elements of its input."""
def __init__(self, input_dataset, filename):
"""See `Dataset.cache()` for details."""
- super(CacheDataset, self).__init__()
+ super(CacheDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._filename = ops.convert_to_tensor(
filename, dtype=dtypes.string, name="filename")
@@ -1840,7 +1931,7 @@ class CacheDataset(Dataset):
return self._input_dataset.output_types
-class ShuffleDataset(Dataset):
+class ShuffleDataset(UnaryDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""
def __init__(self,
@@ -1868,7 +1959,7 @@ class ShuffleDataset(Dataset):
Raises:
ValueError: if invalid arguments are provided.
"""
- super(ShuffleDataset, self).__init__()
+ super(ShuffleDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
@@ -1900,12 +1991,12 @@ class ShuffleDataset(Dataset):
return self._input_dataset.output_types
-class TakeDataset(Dataset):
+class TakeDataset(UnaryDataset):
"""A `Dataset` containing the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.take()` for details."""
- super(TakeDataset, self).__init__()
+ super(TakeDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1928,12 +2019,12 @@ class TakeDataset(Dataset):
return self._input_dataset.output_types
-class SkipDataset(Dataset):
+class SkipDataset(UnaryDataset):
"""A `Dataset` skipping the first `count` elements from its input."""
def __init__(self, input_dataset, count):
"""See `Dataset.skip()` for details."""
- super(SkipDataset, self).__init__()
+ super(SkipDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._count = ops.convert_to_tensor(count, dtype=dtypes.int64, name="count")
@@ -1956,12 +2047,12 @@ class SkipDataset(Dataset):
return self._input_dataset.output_types
-class BatchDataset(Dataset):
+class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, drop_remainder):
"""See `Dataset.batch()` for details."""
- super(BatchDataset, self).__init__()
+ super(BatchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name="batch_size")
@@ -2110,13 +2201,13 @@ def _default_padding(input_dataset):
return nest.map_structure(make_zero, input_dataset.output_types)
-class PaddedBatchDataset(Dataset):
+class PaddedBatchDataset(UnaryDataset):
"""A `Dataset` that batches and pads contiguous elements from its input."""
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values,
drop_remainder):
"""See `Dataset.batch()` for details."""
- super(PaddedBatchDataset, self).__init__()
+ super(PaddedBatchDataset, self).__init__(input_dataset)
if sparse.any_sparse(input_dataset.output_classes):
# TODO(b/63669786): support batching of sparse tensors
raise TypeError(
@@ -2216,12 +2307,12 @@ def _warn_if_collections(transformation_name):
% transformation_name)
-class MapDataset(Dataset):
+class MapDataset(UnaryDataset):
"""A `Dataset` that maps a function over elements in its input."""
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
"""See `Dataset.map()` for details."""
- super(MapDataset, self).__init__()
+ super(MapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._use_inter_op_parallelism = use_inter_op_parallelism
@@ -2282,12 +2373,12 @@ class ParallelMapDataset(MapDataset):
# pylint: enable=protected-access
-class FlatMapDataset(Dataset):
+class FlatMapDataset(UnaryDataset):
"""A `Dataset` that maps a function over its input and flattens the result."""
def __init__(self, input_dataset, map_func):
"""See `Dataset.flat_map()` for details."""
- super(FlatMapDataset, self).__init__()
+ super(FlatMapDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
@@ -2378,12 +2469,12 @@ class ParallelInterleaveDataset(FlatMapDataset):
return "Dataset.interleave()"
-class FilterDataset(Dataset):
+class FilterDataset(UnaryDataset):
"""A `Dataset` that filters its input according to a predicate function."""
def __init__(self, input_dataset, predicate):
"""See `Dataset.filter()` for details."""
- super(FilterDataset, self).__init__()
+ super(FilterDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
wrapped_func = StructuredFunctionWrapper(
predicate, "Dataset.filter()", input_dataset)
@@ -2413,12 +2504,12 @@ class FilterDataset(Dataset):
return self._input_dataset.output_types
-class PrefetchDataset(Dataset):
+class PrefetchDataset(UnaryDataset):
"""A `Dataset` that asynchronously prefetches its input."""
def __init__(self, input_dataset, buffer_size):
"""See `Dataset.prefetch()` for details."""
- super(PrefetchDataset, self).__init__()
+ super(PrefetchDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
if buffer_size is None:
buffer_size = -1 # This is the sentinel for auto-tuning.
@@ -2442,3 +2533,53 @@ class PrefetchDataset(Dataset):
@property
def output_types(self):
return self._input_dataset.output_types
+
+
+class WindowDataset(UnaryDataset):
+ """A dataset that creates window datasets from the input elements."""
+
+ def __init__(self, input_dataset, size, shift, stride, drop_remainder):
+ """See `window_dataset()` for more details."""
+ super(WindowDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.int64, name="size")
+ self._shift = ops.convert_to_tensor(shift, dtype=dtypes.int64, name="shift")
+ self._stride = ops.convert_to_tensor(
+ stride, dtype=dtypes.int64, name="stride")
+ self._drop_remainder = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+ self._output_classes = nest.pack_sequence_as(
+ input_dataset.output_classes,
+ [
+ _NestedDatasetComponent( # pylint: disable=protected-access
+ output_classes=output_class,
+ output_shapes=output_shape,
+ output_types=output_type)
+ for output_class, output_shape, output_type in zip(
+ nest.flatten(input_dataset.output_classes),
+ nest.flatten(input_dataset.output_shapes),
+ nest.flatten(input_dataset.output_types))
+ ])
+ self._output_shapes = self._output_classes
+ self._output_types = self._output_classes
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._size,
+ self._shift,
+ self._stride,
+ self._drop_remainder,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 8f8e026df9..cae00cdbfc 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -24,6 +24,7 @@ from tensorflow.python.compat import compat
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -85,10 +86,10 @@ class Iterator(checkpointable.CheckpointableBase):
initializer: A `tf.Operation` that should be run to initialize this
iterator.
output_types: A nested structure of `tf.DType` objects corresponding to
- each component of an element of this dataset.
+ each component of an element of this iterator.
output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of an element of this dataset.
- output_classes: A nested structure of Python `type` object corresponding
+ corresponding to each component of an element of this iterator.
+ output_classes: A nested structure of Python `type` objects corresponding
to each component of an element of this iterator.
"""
self._iterator_resource = iterator_resource
@@ -670,6 +671,6 @@ def get_next_as_optional(iterator):
output_shapes=nest.flatten(
sparse.as_dense_shapes(iterator.output_shapes,
iterator.output_classes))),
- output_shapes=iterator.output_shapes,
- output_types=iterator.output_types,
- output_classes=iterator.output_classes)
+ structure.Structure._from_legacy_structure(iterator.output_types,
+ iterator.output_shapes,
+ iterator.output_classes))
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
new file mode 100644
index 0000000000..b7d3aac206
--- /dev/null
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -0,0 +1,231 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class _PerDeviceGenerator(dataset_ops.Dataset):
+ """A `dummy` generator dataset."""
+
+ def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
+ source_device, target_device, output_shapes, output_types,
+ output_classes):
+ self._target_device = target_device
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+ self._output_classes = output_classes
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes))
+
+ multi_device_iterator_string_handle = (
+ gen_dataset_ops.multi_device_iterator_to_string_handle(
+ multi_device_iterator_resource))
+
+ @function.Defun()
+ def _init_func():
+ return multi_device_iterator_string_handle
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ multi_device_iterator = (
+ gen_dataset_ops.multi_device_iterator_from_string_handle(
+ string_handle=string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+ return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
+ multi_device_iterator=multi_device_iterator,
+ shard_num=shard_num,
+ incarnation_id=incarnation_id,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(unused_string_handle):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ def _inputs(self):
+ # TODO(b/116506223): Determine which datasets should be used as inputs here.
+ return []
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+class MultiDeviceIterator(object):
+ """An iterator over multiple devices.
+
+ @compatibility(eager)
+ MultiDeviceIterator isn't currently supported in Eager mode but support is
+ coming soon.
+ @end_compatibility
+ """
+
+ def __init__(self,
+ dataset,
+ devices,
+ max_buffer_size=1,
+ prefetch_buffer_size=1,
+ source_device="/cpu:0"):
+ """Constructs a MultiDeviceIterator.
+
+ Args:
+ dataset: The input dataset to be iterated over.
+ devices: The list of devices to fetch data to.
+ max_buffer_size: Maximum size of the host side per device buffer to keep.
+ prefetch_buffer_size: if > 1, then we setup a buffer on each device
+ to prefetch into.
+ source_device: The host device to place the `dataset` on.
+
+ Raises:
+ RuntimeError: If run in Eager mode.
+ """
+ if context.executing_eagerly():
+ # TODO(rohanj): Fix this. Tracking bug: b/116467184
+ raise RuntimeError("MultiDeviceIterator is not currently supported in "
+ "Eager mode.")
+ self._dataset = dataset
+ self._devices = devices
+ self._source_device = source_device
+ self._source_device_tensor = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._dataset.output_shapes,
+ self._dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._dataset.output_types,
+ self._dataset.output_classes))
+
+ # Create the MultiDeviceIterator.
+ with ops.device(self._source_device):
+ self._multi_device_iterator_resource = (
+ gen_dataset_ops.multi_device_iterator(
+ devices=self._devices,
+ shared_name="",
+ container="",
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes))
+
+ # The incarnation ID is used to ensure consistency between the per-device
+ # iterators and the multi-device iterator.
+ self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+ self._dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._multi_device_iterator_resource,
+ max_buffer_size=max_buffer_size)
+
+ # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+ # initialize the device side of the pipeline. This would allow the
+ # MultiDeviceIterator to choose, for example, to move some transformations
+ # into the device side from its input. It might be useful in rewriting.
+ # Create the per device iterators.
+ self._device_iterators = []
+ i = 0
+ for device in self._devices:
+ ds = _PerDeviceGenerator(
+ i, self._multi_device_iterator_resource, self._incarnation_id,
+ self._source_device_tensor, device, self._dataset.output_shapes,
+ self._dataset.output_types, self._dataset.output_classes)
+ if prefetch_buffer_size > 0:
+ ds = ds.prefetch(prefetch_buffer_size)
+ with ops.device(device):
+ self._device_iterators.append(ds.make_initializable_iterator())
+ i += 1
+
+ device_iterator_initializers = [
+ iterator.initializer for iterator in self._device_iterators
+ ]
+ self._initializer = control_flow_ops.group(*device_iterator_initializers)
+
+ def get_next(self):
+ result = []
+ i = 0
+ for device in self._devices:
+ with ops.device(device):
+ result.append(self._device_iterators[i].get_next())
+ i += 1
+ return result
+
+ @property
+ def initializer(self):
+ return self._initializer
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index b75b98dc72..3bbebd7878 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -19,11 +19,9 @@ from __future__ import print_function
import abc
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_dataset_ops
@@ -67,36 +65,14 @@ class Optional(object):
raise NotImplementedError("Optional.get_value()")
@abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of this optional.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of this optional.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of this optional.
- """
- raise NotImplementedError("Optional.output_shapes")
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of this optional.
+ def value_structure(self):
+ """The structure of the components of this optional.
Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of this optional.
+ A `Structure` object representing the structure of the components of this
+ optional.
"""
- raise NotImplementedError("Optional.output_types")
+ raise NotImplementedError("Optional.value_structure")
@staticmethod
def from_value(value):
@@ -108,48 +84,30 @@ class Optional(object):
Returns:
An `Optional` that wraps `value`.
"""
- # TODO(b/110122868): Consolidate this destructuring logic with the
- # similar code in `Dataset.from_tensors()`.
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
- value = nest.pack_sequence_as(value, [
- sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(value))
- ])
-
- encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
- output_classes = sparse.get_classes(value)
- output_shapes = nest.pack_sequence_as(
- value, [t.get_shape() for t in nest.flatten(value)])
- output_types = nest.pack_sequence_as(
- value, [t.dtype for t in nest.flatten(value)])
+ value_structure = structure.Structure.from_value(value)
+ encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access
return _OptionalImpl(
gen_dataset_ops.optional_from_value(encoded_value, name=scope),
- output_shapes, output_types, output_classes)
+ value_structure)
@staticmethod
- def none_from_structure(output_shapes, output_types, output_classes):
+ def none_from_structure(value_structure):
"""Returns an `Optional` that has no value.
- NOTE: This method takes arguments that define the structure of the value
+ NOTE: This method takes an argument that defines the structure of the value
that would be contained in the returned `Optional` if it had a value.
Args:
- output_shapes: A nested structure of `tf.TensorShape` objects
- corresponding to each component of this optional.
- output_types: A nested structure of `tf.DType` objects corresponding to
- each component of this optional.
- output_classes: A nested structure of Python `type` objects corresponding
- to each component of this optional.
+ value_structure: A `Structure` object representing the structure of the
+ components of this optional.
Returns:
An `Optional` that has no value.
"""
- return _OptionalImpl(gen_dataset_ops.optional_none(), output_shapes,
- output_types, output_classes)
+ return _OptionalImpl(gen_dataset_ops.optional_none(), value_structure)
class _OptionalImpl(Optional):
@@ -159,20 +117,9 @@ class _OptionalImpl(Optional):
`Optional.__init__()` in the public API.
"""
- def __init__(self, variant_tensor, output_shapes, output_types,
- output_classes):
- # TODO(b/110122868): Consolidate the structure validation logic with the
- # similar logic in `Iterator.from_structure()` and
- # `Dataset.from_generator()`.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- nest.assert_same_structure(output_types, output_shapes)
- nest.assert_same_structure(output_types, output_classes)
+ def __init__(self, variant_tensor, value_structure):
self._variant_tensor = variant_tensor
- self._output_shapes = output_shapes
- self._output_types = output_types
- self._output_classes = output_classes
+ self._value_structure = value_structure
def has_value(self, name=None):
return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
@@ -182,28 +129,55 @@ class _OptionalImpl(Optional):
# in `Iterator.get_next()` and `StructuredFunctionWrapper`.
with ops.name_scope(name, "OptionalGetValue",
[self._variant_tensor]) as scope:
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(
- self._output_types,
- gen_dataset_ops.optional_get_value(
- self._variant_tensor,
- name=scope,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types,
- self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_shapes(self._output_shapes,
- self._output_classes)))),
- self._output_types, self._output_shapes, self._output_classes)
+ # pylint: disable=protected-access
+ return self._value_structure._from_tensor_list(
+ gen_dataset_ops.optional_get_value(
+ self._variant_tensor,
+ name=scope,
+ output_types=self._value_structure._flat_types,
+ output_shapes=self._value_structure._flat_shapes))
@property
- def output_classes(self):
- return self._output_classes
+ def value_structure(self):
+ return self._value_structure
+
+
+class OptionalStructure(structure.Structure):
+ """Represents an optional potentially containing a structured value."""
+
+ def __init__(self, value_structure):
+ self._value_structure = value_structure
@property
- def output_shapes(self):
- return self._output_shapes
+ def _flat_shapes(self):
+ return [tensor_shape.scalar()]
@property
- def output_types(self):
- return self._output_types
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, OptionalStructure) and
+ self._value_structure.is_compatible_with(other._value_structure))
+
+ def _to_tensor_list(self, value):
+ return [value._variant_tensor] # pylint: disable=protected-access
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "OptionalStructure corresponds to a single tf.variant scalar.")
+ # pylint: disable=protected-access
+ return _OptionalImpl(flat_value[0], self._value_structure)
+
+ @staticmethod
+ def from_value(value):
+ return OptionalStructure(value.value_structure)
+
+
+# pylint: disable=protected-access
+structure.Structure._register_custom_converter(Optional,
+ OptionalStructure.from_value)
+# pylint: enable=protected-access
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 066e09969c..b0f26631f9 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -61,6 +61,9 @@ class TextLineDataset(dataset_ops.Dataset):
return gen_dataset_ops.text_line_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -105,6 +108,9 @@ class _TFRecordDataset(dataset_ops.Dataset):
return gen_dataset_ops.tf_record_dataset(
self._filenames, self._compression_type, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
@@ -224,6 +230,9 @@ class TFRecordDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
return self._impl._as_variant_tensor() # pylint: disable=protected-access
+ def _inputs(self):
+ return self._impl._inputs() # pylint: disable=protected-access
+
@property
def output_classes(self):
return self._impl.output_classes
@@ -278,6 +287,9 @@ class FixedLengthRecordDataset(dataset_ops.Dataset):
self._filenames, self._header_bytes, self._record_bytes,
self._footer_bytes, self._buffer_size)
+ def _inputs(self):
+ return []
+
@property
def output_classes(self):
return ops.Tensor
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 9d621fcd30..e5abc654da 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -96,37 +96,11 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if `seq` is a Sequence or dict (except strings/lists).
+# See the swig file (../../util/util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequenceForData
- NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
- which *does* treat a Python list as a sequence. For ergonomic
- reasons, `tf.data` users would prefer to treat lists as
- implicit `tf.Tensor` objects, and dicts as (nested) sequences.
-
- Args:
- seq: an input sequence.
-
- Returns:
- True if the sequence is a not a string or list and is a
- collections.Sequence.
- """
- return _pywrap_tensorflow.IsSequenceForData(seq)
-
-
-def flatten(nest):
- """Returns a flat sequence from a given nested structure.
-
- If `nest` is not a sequence, this returns a single-element list: `[nest]`.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object.
- Note, numpy arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
- """
- return _pywrap_tensorflow.FlattenForData(nest)
+# See the swig file (../../util/util.i) for documentation.
+flatten = _pywrap_tensorflow.FlattenForData
def assert_same_structure(nest1, nest2, check_types=True):
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index c5764b8dfe..a90ca258c0 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -28,6 +28,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import sparse_ops
+_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
+
+
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
@@ -64,12 +67,10 @@ class Structure(object):
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractmethod
- def is_compatible_with(self, value):
- """Returns `True` if `value` is compatible with this structure.
+ def is_compatible_with(self, other):
+ """Returns `True` if `other` is compatible with this structure.
- A value `value` is compatible with a structure `s` if
- `Structure.from_value(value)` would return a structure `t` that is a
- "subtype" of `s`. A structure `t` is a "subtype" of `s` if:
+ A structure `t` is a "subtype" of `s` if:
* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
@@ -83,10 +84,10 @@ class Structure(object):
`tf.TensorShape.is_compatible_with`.
Args:
- value: A potentially structured value.
+ other: A `Structure`.
Returns:
- `True` if `value` matches this structure, otherwise `False`.
+ `True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")
@@ -98,7 +99,7 @@ class Structure(object):
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.
- Requires: `self.is_compatible_with(value)`.
+ Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
@@ -137,9 +138,8 @@ class Structure(object):
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
-
- # TODO(b/110122868): Add support for custom types, Dataset, and Optional
- # to this method.
+ # TODO(b/110122868): Add support for custom types and Dataset to this
+ # method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
@@ -147,12 +147,76 @@ class Structure(object):
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
+ for converter_type, converter_fn in (
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
+ if isinstance(value, converter_type):
+ return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)
+ @staticmethod
+ def _from_legacy_structure(output_types, output_shapes, output_classes):
+ """Returns a `Structure` that represents the given legacy structure.
+
+ This method provides a way to convert from the existing `Dataset` and
+ `Iterator` structure-related properties to a `Structure` object.
+
+ TODO(b/110122868): Remove this method once `Structure` is used throughout
+ `tf.data`.
+
+ Args:
+ output_types: A nested structure of `tf.DType` objects corresponding to
+ each component of a structured value.
+ output_shapes: A nested structure of `tf.TensorShape` objects
+ corresponding to each component a structured value.
+ output_classes: A nested structure of Python `type` objects corresponding
+ to each component of a structured value.
+
+ Returns:
+ A `Structure`.
+
+ Raises:
+ TypeError: If a structure cannot be built the arguments, because one of
+ the component classes in `output_classes` is not supported.
+ """
+ flat_types = nest.flatten(output_types)
+ flat_shapes = nest.flatten(output_shapes)
+ flat_classes = nest.flatten(output_classes)
+ flat_ret = []
+ for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
+ flat_classes):
+ if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
+ flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
+ elif issubclass(flat_class, ops.Tensor):
+ flat_ret.append(TensorStructure(flat_type, flat_shape))
+ else:
+ # NOTE(mrry): Since legacy structures produced by iterators only
+ # comprise Tensors, SparseTensors, and nests, we do not need to support
+ # all structure types here.
+ raise TypeError(
+ "Could not build a structure for output class %r" % flat_type)
+
+ ret = nest.pack_sequence_as(output_classes, flat_ret)
+ if isinstance(ret, Structure):
+ return ret
+ else:
+ return NestedStructure(ret)
+
+ @staticmethod
+ def _register_custom_converter(type_object, converter_fn):
+ """Registers `converter_fn` for converting values of the given type.
+
+ Args:
+ type_object: A Python `type` object representing the type of values
+ accepted by `converter_fn`.
+ converter_fn: A function that takes one argument (an instance of the
+ type represented by `type_object`) and returns a `Structure`.
+ """
+ _STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn
+
# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
@@ -179,16 +243,21 @@ class NestedStructure(Structure):
def _flat_types(self):
return self._flat_types_list
- def is_compatible_with(self, value):
+ def is_compatible_with(self, other):
+ if not isinstance(other, NestedStructure):
+ return False
try:
- nest.assert_shallow_structure(self._nested_structure, value)
+ # pylint: disable=protected-access
+ nest.assert_same_structure(self._nested_structure,
+ other._nested_structure)
except (ValueError, TypeError):
return False
return all(
- s.is_compatible_with(v) for s, v in zip(
+ substructure.is_compatible_with(other_substructure)
+ for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
- nest.flatten_up_to(self._nested_structure, value)))
+ nest.flatten(other._nested_structure)))
def _to_tensor_list(self, value):
ret = []
@@ -201,7 +270,7 @@ class NestedStructure(Structure):
for sub_value, structure in zip(flat_value,
nest.flatten(self._nested_structure)):
- if not structure.is_compatible_with(sub_value):
+ if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
@@ -242,17 +311,13 @@ class TensorStructure(Structure):
def _flat_types(self):
return [self._dtype]
- def is_compatible_with(self, value):
- try:
- value = ops.convert_to_tensor(value, dtype=self._dtype)
- except (ValueError, TypeError):
- return False
-
- return (self._dtype.is_compatible_with(value.dtype) and
- self._shape.is_compatible_with(value.shape))
+ def is_compatible_with(self, other):
+ return (isinstance(other, TensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._shape.is_compatible_with(other._shape))
def _to_tensor_list(self, value):
- if not self.is_compatible_with(value):
+ if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]
@@ -260,7 +325,7 @@ class TensorStructure(Structure):
def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
- if not self.is_compatible_with(flat_value[0]):
+ if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return flat_value[0]
@@ -285,16 +350,10 @@ class SparseTensorStructure(Structure):
def _flat_types(self):
return [dtypes.variant]
- def is_compatible_with(self, value):
- try:
- value = sparse_tensor_lib.SparseTensor.from_value(value)
- except TypeError:
- return False
- return (isinstance(value, (sparse_tensor_lib.SparseTensor,
- sparse_tensor_lib.SparseTensorValue)) and
- self._dtype.is_compatible_with(value.dtype) and
- self._dense_shape.is_compatible_with(
- tensor_util.constant_value_as_shape(value.dense_shape)))
+ def is_compatible_with(self, other):
+ return (isinstance(other, SparseTensorStructure) and
+ self._dtype.is_compatible_with(other._dtype) and
+ self._dense_shape.is_compatible_with(other._dense_shape))
def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index d0c7df67ae..2982763181 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -25,7 +25,9 @@ from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -106,13 +108,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
}, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
)
- def testIsCompatibleWith(self, original_value, compatible_values,
- incompatible_values):
+ def testIsCompatibleWithStructure(self, original_value, compatible_values,
+ incompatible_values):
s = structure.Structure.from_value(original_value)
for compatible_value in compatible_values:
- self.assertTrue(s.is_compatible_with(compatible_value))
+ self.assertTrue(
+ s.is_compatible_with(
+ structure.Structure.from_value(compatible_value)))
for incompatible_value in incompatible_values:
- self.assertFalse(s.is_compatible_with(incompatible_value))
+ self.assertFalse(
+ s.is_compatible_with(
+ structure.Structure.from_value(incompatible_value)))
# NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
# will be executed before the (eager- or graph-mode) test environment has been
@@ -322,6 +328,28 @@ class StructureTest(test.TestCase, parameterized.TestCase):
ValueError, "Expected 3 flat values in NestedStructure but got 2."):
s_2._from_tensor_list(flat_s_1)
+ @parameterized.named_parameters(
+ ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
+ sparse_tensor.SparseTensor,
+ structure.SparseTensorStructure(dtypes.int32, [2, 2])),
+ ("Nest",
+ {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
+ {"a": tensor_shape.scalar(),
+ "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())},
+ {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
+ structure.TensorStructure(dtypes.string, []))})),
+ )
+ def testFromLegacyStructure(self, output_types, output_shapes, output_classes,
+ expected_structure):
+ actual_structure = structure.Structure._from_legacy_structure(
+ output_types, output_shapes, output_classes)
+ self.assertTrue(expected_structure.is_compatible_with(actual_structure))
+ self.assertTrue(actual_structure.is_compatible_with(expected_structure))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 55231954d1..4630bda590 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -57,7 +57,8 @@ def no_rewrite_session_config():
disable_model_pruning=True,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
- dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
diff --git a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
index 676097fde9..1f67f8a0d4 100644
--- a/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
+++ b/tensorflow/python/debug/lib/debug_graph_reconstruction_test.py
@@ -45,6 +45,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
def _no_rewrite_session_config(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF,
min_graph_nodes=-1)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
return config_pb2.ConfigProto(graph_options=graph_options)
@@ -156,7 +157,7 @@ class ReconstructNonDebugGraphTest(test_util.TensorFlowTestCase):
sess, cond, expected_output=21.0)
def testReconstructGraphWithWhileLoop(self):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
loop_body = lambda i: math_ops.add(i, 2)
loop_cond = lambda i: math_ops.less(i, 16)
i = constant_op.constant(10, name="i")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index ff49b69547..91f21cb1f3 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -741,7 +741,7 @@ class DelayedDebugServerTest(test_util.TensorFlowTestCase):
debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
server_start_delay_sec=2.0, dump_to_filesystem=False)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_init = constant_op.constant(42.0, name="a_init")
a = variables.Variable(a_init, name="a")
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index e17a598123..8daa34c885 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -182,6 +182,7 @@ def should_run_distribute_coordinator(config):
# pylint: disable=protected-access
if (not hasattr(config, '_distribute_coordinator_mode') or
config._distribute_coordinator_mode is None):
+ logging.info('Not using Distribute Coordinator.')
return False
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
config._distribute_coordinator_mode not in [
@@ -221,15 +222,28 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._train_distribute = strategy
- _init_run_config_from_worker_context(
- local_estimator._config, dc_context.get_current_worker_context())
+ context = dc_context.get_current_worker_context()
+ _init_run_config_from_worker_context(local_estimator._config, context)
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._train_distribution = strategy
# pylint: enable=protected-access
+ # In the standalone client, we don't need to run hooks on all threads
+ # because logging hooks on all threads may be too much on the screen; also
+ # tensor passed to one hook can only be fetched with the graph where the
+ # tensor is defined. Other hooks such as checkpointing hooks will added by
+ # MonitoredTrainingSession.
+ # TODO(yuefengz): Is there a hook that does need to run on all threads in
+ # standalone client mode?
+ if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access
+ dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
+ hooks = list(train_spec.hooks)
+ else:
+ hooks = []
local_estimator.train(
input_fn=train_spec.input_fn,
max_steps=train_spec.max_steps,
- hooks=list(train_spec.hooks))
+ hooks=hooks)
def _eval_fn(strategy):
"""Function for evaluator task."""
@@ -238,6 +252,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator._config._eval_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._eval_distribution = strategy
executor = executor_cls(local_estimator, train_spec, eval_spec)
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 85da1baaf0..d3d997e6df 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -17,7 +17,10 @@ cc_library(
"pywrap_tensor.h",
"pywrap_tfe.h",
],
- visibility = ["//tensorflow:internal"],
+ visibility = [
+ "//learning/deepmind/courier:__pkg__",
+ "//tensorflow:internal",
+ ],
deps = [
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
@@ -34,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/types:variant",
],
)
@@ -45,6 +49,7 @@ py_library(
":backprop",
":context",
":core",
+ ":def_function",
":execute",
":function",
":graph_only_ops",
@@ -146,6 +151,7 @@ cuda_py_test(
"//tensorflow/python:clip_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
+ "//tensorflow/python:list_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
],
@@ -345,6 +351,7 @@ py_test(
deps = [
":backprop",
":context",
+ ":core",
":test",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
@@ -377,3 +384,30 @@ cuda_py_test(
"optonly", # The test is too slow in non-opt mode
],
)
+
+py_library(
+ name = "def_function",
+ srcs = ["def_function.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":context",
+ ":function",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/training/checkpointable:base",
+ ],
+)
+
+py_test(
+ name = "def_function_test",
+ srcs = ["def_function_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":def_function",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ ],
+)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index be392c7a0f..78f3198011 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -120,27 +120,6 @@ def _gradient_function(op_name, attr_tuple, num_inputs, inputs, outputs,
pywrap_tensorflow.TFE_Py_RegisterGradientFunction(_gradient_function)
-_tracing = False
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
- "SoftmaxCrossEntropyWithLogits": [1],
- "FusedBatchNorm": [1, 2, 3, 4]
-}
-
-
def _record_gradient(op_name, inputs, attrs, results, name):
return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs,
results, name)
@@ -585,7 +564,10 @@ def _aggregate_grads(gradients):
def _num_elements(grad):
"""The number of elements in the `grad` tensor."""
if isinstance(grad, ops.Tensor):
- return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access
+ shape_tuple = grad._shape_tuple() # pylint: disable=protected-access
+ if shape_tuple is None or None in shape_tuple:
+ return 0
+ return functools.reduce(operator.mul, shape_tuple, 1)
if isinstance(grad, ops.IndexedSlices):
return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access
raise ValueError("`grad` not a Tensor or IndexedSlices.")
@@ -629,8 +611,9 @@ def _ones(shape, dtype):
_default_vspace = imperative_grad.VSpace(
num_elements_fn=_num_elements,
aggregate_fn=_aggregate_grads,
- zeros=_zeros,
- ones=_ones)
+ zeros_fn=_zeros,
+ ones_fn=_ones,
+ graph_shape_fn=gen_array_ops.shape)
pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
@@ -648,8 +631,8 @@ class GradientTape(object):
Operations are recorded if they are executed within this context manager and
at least one of their inputs is being "watched".
- Trainable variables (created by `tf.Variable` or `tf.get_variable`,
- trainable=True is default in both cases) are automatically watched. Tensors
+ Trainable variables (created by `tf.Variable` or `tf.get_variable`, where
+ `trainable=True` is default in both cases) are automatically watched. Tensors
can be manually watched by invoking the `watch` method on this context
manager.
@@ -669,6 +652,7 @@ class GradientTape(object):
```python
x = tf.constant(3.0)
with tf.GradientTape() as g:
+ g.watch(x)
with tf.GradientTape() as gg:
gg.watch(x)
y = x * x
@@ -745,7 +729,9 @@ class GradientTape(object):
self._persistent = persistent
self._watch_accessed_variables = watch_accessed_variables
self._recording = False
- context.context().start_step()
+ self._created_eagerly = context.executing_eagerly()
+ if self._created_eagerly:
+ context.context().start_step()
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
@@ -775,7 +761,8 @@ class GradientTape(object):
self._recording = False
def __del__(self):
- context.context().end_step()
+ if self._created_eagerly:
+ context.context().end_step()
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index f938ed5df8..32731747b7 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -1022,6 +1022,18 @@ class BackpropTest(test.TestCase):
resource_variable_ops.ResourceVariable(2.0))
self.assertAllEqual(gradients_constants, gradients_variables)
+ def testUnknownShapes(self):
+ with context.graph_mode():
+ with backprop.GradientTape() as tape:
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ tape.watch(a)
+ b = a**3
+
+ db_da = tape.gradient(b, a)
+
+ with self.cached_session() as sess:
+ self.assertEqual((8.0, 12.0), sess.run((b, db_da), feed_dict={a: 2.0}))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
new file mode 100644
index 0000000000..8dcacd5c99
--- /dev/null
+++ b/tensorflow/python/eager/def_function.py
@@ -0,0 +1,235 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=unidiomatic-typecheck
+"""Prototype decorator for defining graph-mode functions with eager semantics."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training.checkpointable import base as checkpointable
+
+
+class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
+ """Variable which does not lift its initializer out of function context.
+
+ Instances of this variable, when created, build a graph which runs their
+ initializer inside a tf.cond(is_initialized) block.
+
+ This can only be created inside a defun called from (eventually) eager
+ mode. That is, non-function-building graphs are not supported.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ caching_device=None,
+ name=None,
+ dtype=None,
+ constraint=None,
+ **unused_kwargs):
+ """Creates a variable.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called.
+ (Note that initializer functions from init_ops.py must first be bound
+ to a shape before being used here.)
+ trainable: If `True`, GradientTapes automatically watch uses of this
+ Variable.
+ caching_device: Optional device string or function describing where the
+ Variable should be cached for reading. Defaults to the Variable's
+ device. If not `None`, caches on another device. Typical use is to
+ cache on the device where the Ops using the Variable reside, to
+ deduplicate copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If None, either the datatype will be kept (if initial_value is
+ a Tensor) or float32 will be used (if it is a Python object convertible
+ to a Tensor).
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+
+ Raises:
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If called outside of a function definition.
+ """
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable should not be created "
+ "outside of functions.")
+ with ops.init_scope():
+ if not context.executing_eagerly():
+ raise RuntimeError(
+ "UnliftedInitializerVariable does not support legacy graph mode.")
+ self._in_graph_mode = False
+ if initial_value is None:
+ raise ValueError("initial_value must be specified.")
+ init_from_fn = callable(initial_value)
+
+ if constraint is not None and not callable(constraint):
+ raise ValueError("The `constraint` argument must be a callable.")
+
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
+ self._trainable = trainable
+ self._save_slice_info = None
+ self._initial_value = None
+ self._initializer_op = None
+ self._is_initialized_op = None
+ self._graph_element = None
+ self._cached_value = None
+ # Store the graph key so optimizers know how to only retrieve variables from
+ # this graph. Guaranteed to be the same as the eager graph_key.
+ self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access
+ with ops.name_scope(name, "Variable", []
+ if init_from_fn else [initial_value]) as name:
+ # pylint: disable=protected-access
+ with ops.init_scope():
+ assert context.executing_eagerly()
+ shared_name = ops._name_from_scope_name(name)
+ shared_name = "%s_%d" % (shared_name, ops.uid())
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ with ops.name_scope("Initializer"), ops.device(None):
+ initial_value = ops.convert_to_tensor(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ with ops.init_scope():
+ self._handle = resource_variable_ops.eager_safe_variable_handle(
+ shape=initial_value.get_shape(),
+ dtype=initial_value.dtype.base_dtype,
+ shared_name=shared_name,
+ name=name,
+ graph_mode=False)
+ self._shape = initial_value.shape
+ self._unique_id = shared_name
+ self._handle_name = shared_name + ":0"
+ self._dtype = initial_value.dtype.base_dtype
+ self._constraint = constraint
+ assert initial_value is not None
+ def assign_fn():
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ initial_value,
+ name=n)
+ # Returning values to keep tf.cond happy.
+ return ops.convert_to_tensor(1)
+ def not_assign_fn():
+ return ops.convert_to_tensor(0)
+ # Note: this cond is always guaranteed to run because we're inside a defun
+ # which will insert automatic control dependencies.
+ control_flow_ops.cond(
+ resource_variable_ops.var_is_initialized_op(self._handle),
+ not_assign_fn, assign_fn)
+
+ # After the handle has been created, set up a way to clean it up when
+ # executing eagerly. We'll hold the only reference to the deleter, so that
+ # when this object is garbage collected the deleter will be too. This
+ # means ResourceVariables can be part of reference cycles without those
+ # cycles being uncollectable.
+ self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._handle, handle_device=self._handle.device)
+ self._cached_shape_as_list = None
+
+
+def _defun_with_scope(scope, fn):
+
+ def wrapped_fn(*args, **kwds):
+ with variable_scope.variable_creator_scope(scope):
+ return fn(*args, **kwds)
+
+ return function.defun(wrapped_fn)
+
+
+def def_function(fn):
+ """Defines a function as per the "functions, not sessions" document."""
+
+ # Wrapping the values in lists to bypass python's lack of way to mutate
+ # symbols from an outer scope.
+ first_call = [True]
+ function_to_call = []
+
+ # TODO(apassos) represent this as an object and not as a closure.
+ def decorated_fn(*args, **kwds):
+ """Graph function for fn."""
+ if not first_call[0]:
+ return function_to_call[0](*args, **kwds)
+
+ first_call[0] = False
+ created_variables = []
+
+ def variable_creator_scope(unused_next_creator, **kwds):
+ """Creates UnliftedInitializerVariables and saves references to them."""
+ v = UnliftedInitializerVariable(**kwds)
+ created_variables.append(v)
+ return v
+
+ first_graph_function = _defun_with_scope(variable_creator_scope, fn)
+
+ # Force the definition of the function for these arguments
+ first_concrete = first_graph_function.get_concrete_function(*args, **kwds)
+
+ def invalid_creator_scope(*unused_args, **unused_kwds):
+ """Disables variable creation."""
+ raise ValueError(
+ "def_function-decorated function tried to create "
+ "variables on second call.")
+
+ second_graph_function = _defun_with_scope(invalid_creator_scope, fn)
+
+ function_to_call.append(second_graph_function)
+ if not created_variables:
+ # Note: this retracing might be unnecessary, but running the function
+ # forever in the scope which disallows variable creation is safer than not
+ # doing so.
+ return second_graph_function(*args, **kwds)
+
+ def fn_with_cond(*inner_args, **inner_kwds):
+ """Conditionally runs initialization if it's needed."""
+ condition = True
+ for variable in created_variables:
+ condition = condition and resource_variable_ops.var_is_initialized_op(
+ variable.handle)
+ # We want to call second_graph_function if possible because it avoids
+ # recomputing potentially expensive initializers.
+ return control_flow_ops.cond(
+ condition,
+ lambda: second_graph_function(*inner_args, **inner_kwds),
+ lambda: first_concrete(*inner_args, **inner_kwds))
+
+ return function.defun(fn_with_cond)(*args, **kwds)
+
+ return decorated_fn
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
new file mode 100644
index 0000000000..804436c4bb
--- /dev/null
+++ b/tensorflow/python/eager/def_function_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTest(test.TestCase):
+
+ def testNoVariables(self):
+
+ @def_function.def_function
+ def fn(x):
+ return 2 * x
+
+ self.assertAllEqual(fn(constant_op.constant(4.0)), 8.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnce(self):
+
+ @def_function.def_function
+ def fn(x):
+ return variables.Variable(1.0) + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testFailIfVariablesAreCreatedMoreThanOnceNoWeakRef(self):
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ state.append(variables.Variable(1.0))
+ return state[-1] + x
+
+ with self.assertRaises(ValueError):
+ fn(1.0)
+
+ def testCorrectVariableCreation(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+ def testVariableInitializerNotConstant(self):
+
+ state = []
+
+ @def_function.def_function
+ def fn(x):
+ if not state:
+ state.append(variables.Variable(2.0 * x))
+ return state[0] * x
+
+ self.assertAllEqual(fn(constant_op.constant(1.0)), 2.0)
+ self.assertAllEqual(fn(constant_op.constant(3.0)), 6.0)
+
+
+if __name__ == '__main__':
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 03f12139f6..b28befeb62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -23,10 +23,12 @@ import collections
import functools
import sys
import threading
+import weakref
import numpy as np
import six
+from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
@@ -34,6 +36,7 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
@@ -59,23 +62,47 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-def _create_substitute_placeholder(value, name, dtype=None):
+# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
+WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+
+
+def _create_substitute_placeholder(value, name=None, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
# Note: setting ops.control_dependencies(None) ensures we always put
# capturing placeholders outside of any control flow context.
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if placeholder.dtype == dtypes_module.resource:
- if isinstance(value, ops.EagerTensor):
- handle_data = value._handle_data # pylint: disable=protected-access
+ _copy_handle_data(value, placeholder)
+ return placeholder
+
+
+def _copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes_module.resource or
+ target_t.dtype == dtypes_module.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
else:
- handle_data = resource_variable_ops.get_resource_handle_data(value)
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
- pywrap_tensorflow.SetResourceHandleShapeAndType(
- placeholder.graph._c_graph, placeholder._as_tf_output(),
- handle_data.SerializeToString())
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
# pylint: enable=protected-access
# Ensure that shapes and dtypes are propagated.
shapes, types = zip(*[(pair.shape, pair.dtype)
@@ -84,12 +111,10 @@ def _create_substitute_placeholder(value, name, dtype=None):
shapes = [[d.size for d in s.dim]
if not s.unknown_rank else None for s in shapes]
pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- placeholder._op._graph._c_graph, # pylint: disable=protected-access
- placeholder._as_tf_output(), # pylint: disable=protected-access
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
shapes, ranks, types)
- return placeholder
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
@@ -99,6 +124,44 @@ def _get_device_functions(ctx, graph):
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+def _parse_func_attrs(attributes):
+ """Convert the keyword arguments into function_def attributes.
+
+ Currently only support primitive types: bool, int, float and string.
+
+ Args:
+ attributes: the dictionary of attributes.
+ Returns:
+ A dict of attributes where the key is the name of attribute and the value
+ is the AttrValue proto.
+ Raises:
+ ValueError: If the kwargs contains unwhitelisted name or unsupported value
+ types.
+ """
+ attrs = {}
+ for key, value in attributes.items():
+ if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ raise ValueError("Attribute name is not whitelisted. "
+ "Whitelisted: prefix %s, got: %s" %
+ (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+
+ if isinstance(value, attr_value_pb2.AttrValue):
+ attrs[key] = value
+ # bool type check has to happen before int since bool is a subclass of int.
+ elif isinstance(value, bool):
+ attrs[key] = attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ attrs[key] = attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ attrs[key] = attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (key, type(value)))
+ return attrs
+
+
class FuncGraph(ops.Graph):
"""Graph representing a function body.
@@ -136,7 +199,7 @@ class FuncGraph(ops.Graph):
self.inputs = []
self.outputs = []
self.structured_outputs = None
- self.variables = []
+ self._weak_variables = []
self.outer_graph = ops.get_default_graph()
self.captures = collections.OrderedDict()
@@ -173,6 +236,31 @@ class FuncGraph(ops.Graph):
self._graph_key = graph._graph_key
# pylint: enable=protected-access
+ @property
+ def variables(self):
+ """A list of variables accessed by this FuncGraph.
+
+ Note that functions keep only weak references to variables. Calling the
+ function after a variable it accesses has been deleted is an error.
+
+ Yields:
+ Strong references to variables accessed by this FuncGraph.
+ """
+ for weak_v in self._weak_variables:
+ v = weak_v()
+ if v is None:
+ raise AssertionError(
+ "Called a function referencing variables which have been deleted. "
+ "This likely means that function-local variables were created and "
+ "not referenced elsewhere in the program. This is generally a "
+ "mistake; consider storing variables in an object attribute on "
+ "first call.")
+ yield v
+
+ @variables.setter
+ def variables(self, var_list):
+ self._weak_variables = [weakref.ref(v) for v in var_list]
+
def create_op(
self,
op_type,
@@ -365,6 +453,7 @@ class _EagerDefinedFunction(object):
self._num_outputs = len(self.signature.output_arg)
self._output_types = [o.type for o in self.signature.output_arg]
self._output_shapes = [o.shape for o in outputs]
+ self._func_graph_outputs = outputs
self.grad_func_name = None
self.python_grad_func = None
self._c_func = c_api_util.ScopedTFFunction(fn)
@@ -441,6 +530,8 @@ class _EagerDefinedFunction(object):
else:
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
+ for i, func_graph_output in enumerate(self._func_graph_outputs):
+ _copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -485,7 +576,7 @@ class Function(object):
self._num_outputs = len(self._func_graph.outputs)
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
- self._attrs = attrs or {}
+ self._attrs = _parse_func_attrs(attrs or {})
self._device_functions = tuple(
self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
@@ -506,7 +597,19 @@ class Function(object):
self._distributed_variables[component_variable.handle] = variable
def __call__(self, *args):
- """Executes the wrapped function."""
+ """Executes the wrapped function.
+
+ Args:
+ *args: a list of Tensors or Variables.
+
+ Returns:
+ The result of applying the TF function to `args`.
+
+ Raises:
+ ValueError: If the current device stack does not match the device stack
+ under which the function was created, or if `args` contains anything
+ other than Tensors or Variables.
+ """
ctx = context.context()
device_functions = _get_device_functions(ctx, ops.get_default_graph())
if device_functions != self._device_functions:
@@ -522,7 +625,18 @@ class Function(object):
tape.variable_accessed(v)
captures = self._resolve_captured_inputs()
- tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ tensor_inputs = []
+ for i, arg in enumerate(nest.flatten(args)):
+ if isinstance(arg, resource_variable_ops.ResourceVariable):
+ if arg.trainable:
+ tape.variable_accessed(arg)
+ tensor_inputs.append(arg.handle)
+ elif isinstance(arg, ops.Tensor):
+ tensor_inputs.append(arg)
+ else:
+ raise ValueError("All inputs to `Function`s must be Tensors; "
+ "on invocation of %s, the %d-th input (%s) was not a "
+ "Tensor." % (self._func_graph.name, i, str(arg)))
args = tensor_inputs + captures
if tape.should_record(tensor_inputs) or tape.should_record(captures):
@@ -537,11 +651,6 @@ class Function(object):
return self._func_graph
@property
- def variables(self):
- """Returns all variables touched by this function."""
- return self._func_graph.variables
-
- @property
def inputs(self):
"""Returns tensors in `self.graph` corresponding to arguments."""
return self._func_graph.inputs
@@ -738,7 +847,12 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name,
+ python_func,
+ args,
+ kwargs,
+ signature=None,
+ func_graph=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -746,13 +860,15 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
python_func: the Python function to trace.
args: the positional args with which the Python function should be called;
ignored if a signature is provided.
- kwds: the keyword args with which the Python function should be called;
+ kwargs: the keyword args with which the Python function should be called;
ignored if a signature is provided.
signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
and dtypes of the arguments. When a signature is provided, `args` and
- `kwds` are ignored, and `python_func` is traced with Tensors conforming
+ `kwargs` are ignored, and `python_func` is traced with Tensors conforming
to `signature`. If `None`, the shapes and dtypes are inferred from the
inputs.
+ func_graph: Optional. An instance of FuncGraph. If provided, we will use
+ this graph else a new one is built and returned.
Returns:
A FuncGraph.
@@ -761,26 +877,25 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
- func_graph = FuncGraph(name)
+ if func_graph is None:
+ func_graph = FuncGraph(name)
+ assert isinstance(func_graph, FuncGraph)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
if signature is None:
func_args = _get_defun_inputs_from_args(args)
- func_kwds = _get_defun_inputs_from_args(kwds)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
else:
func_args = _get_defun_inputs_from_signature(signature)
- func_kwds = {}
+ func_kwargs = {}
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
- func_graph.inputs.extend(
- x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
- if isinstance(x, ops.Tensor))
-
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
- func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+ func_kwargs_before = nest.pack_sequence_as(
+ func_kwargs, nest.flatten(func_kwargs))
def convert(x):
"""Converts an argument to a Tensor."""
@@ -799,7 +914,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwds)
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -819,10 +934,32 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
raise ValueError(errmsg)
check_mutation(func_args_before, func_args)
- check_mutation(func_kwds_before, func_kwds)
+ check_mutation(func_kwargs_before, func_kwargs)
finally:
tape.pop_tape(this_tape)
+ # Variables in `func_args`, `func_kwargs` should be explicit inputs
+ # to the function, not captured inputs.
+ tape_variables = this_tape.watched_variables()
+ arg_variables = set()
+ inputs = []
+ for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
+ if isinstance(arg, resource_variable_ops.ResourceVariable):
+ try:
+ resource_placeholder = func_graph.captures.pop(arg.handle)
+ arg_variables.add(arg)
+ except KeyError:
+ # This case occurs if a Variable among the inputs is not actually
+ # used by the function; we still add an explicit input for it
+ # because the user should presumably pass the Variable as an input
+ # to the corresponding graph function.
+ resource_placeholder = _create_substitute_placeholder(arg.handle)
+ inputs.append(resource_placeholder)
+ elif isinstance(arg, ops.Tensor):
+ inputs.append(arg)
+ variables = [v for v in tape_variables if v not in arg_variables]
+ func_graph.inputs = inputs + list(func_graph.captures.values())
+
func_graph.structured_outputs = func_outputs
# Returning a closed-over tensor does not trigger convert_to_tensor.
func_graph.outputs.extend(
@@ -834,7 +971,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
# Instead of storing non-distributed component variables, we
# store their distributed containers so we can retrieve the correct
# component variables at call-time.
- variables = list(this_tape.watched_variables())
strategy = distribution_strategy_context.get_distribution_strategy()
for i, variable in enumerate(variables):
# If variable is not distributed value_container returns itself.
@@ -879,9 +1015,6 @@ def _encode_arg(arg):
_TensorType(arg.values.dtype, arg.values._shape_tuple()),
_TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
])
- elif isinstance(arg, np.ndarray):
- tensor = ops.convert_to_tensor(arg)
- return _TensorType(tensor.dtype, tensor._shape_tuple())
# pylint: enable=protected-access
elif isinstance(arg, (list, tuple)):
return tuple([_encode_arg(elem) for elem in arg])
@@ -889,7 +1022,16 @@ def _encode_arg(arg):
return tuple(
(_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
else:
- return arg
+ try:
+ # If possible, keep only a weak reference to Python objects. Weak
+ # references hash to the same value as the original object.
+ # TODO(allenl): Clean up dead functions and their cache keys if the cache
+ # gets large. Right now creating objects with a defunned method, calling
+ # the method, and losing a reference to the object in a loop will leak
+ # memory here.
+ return weakref.ref(arg)
+ except TypeError:
+ return arg
def _deterministic_dict_values(dictionary):
@@ -911,7 +1053,8 @@ class PolymorphicFunction(object):
def __init__(self,
python_function,
name,
- input_signature=None):
+ input_signature=None,
+ attributes=None):
"""Initializes a polymorphic function.
Args:
@@ -920,6 +1063,8 @@ class PolymorphicFunction(object):
input_signature: a possibly nested sequence of `TensorSpec` objects
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
+ attributes: dict, extra keyword arguments that will be added as attribute
+ of the function.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -929,14 +1074,14 @@ class PolymorphicFunction(object):
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
- self._kwds_to_include = python_function.keywords or {}
+ self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
- self._kwds_to_include = {}
+ self._kwargs_to_include = {}
self._name = name
self._function_cache = collections.OrderedDict()
- self._variables = []
+ self._function_attributes = attributes or {}
self._lock = threading.Lock()
@@ -971,9 +1116,9 @@ class PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- def __call__(self, *args, **kwds):
+ def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ graph_function, inputs = self._maybe_define_function(args, kwargs)
return graph_function(*inputs)
@property
@@ -981,12 +1126,6 @@ class PolymorphicFunction(object):
"""Returns the wrapped Python function."""
return self._python_function
- # TODO(akshayka): Remove this property.
- @property
- def variables(self):
- """Returns the union of all variables referenced by cached `Function`s`."""
- return self._variables
-
def get_concrete_function(self, *args, **kwargs):
"""Returns a `Function` object specialized to inputs and execution context.
@@ -997,7 +1136,7 @@ class PolymorphicFunction(object):
*args: inputs to specialize on.
**kwargs: inputs to specialize on.
"""
- graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ graph_function, _ = self._maybe_define_function(args, kwargs)
return graph_function
def __get__(self, instance, owner):
@@ -1018,33 +1157,37 @@ class PolymorphicFunction(object):
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds, ctx, graph):
+ def _cache_key(self, args, kwargs):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
- inputs = (args, kwds) if kwds else args
+ inputs = (args, kwargs) if kwargs else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
- del args, kwds
+ del args, kwargs
cache_key = self._flat_input_signature
- # The graph, or whether we're executing eagerly, should be a part of the
- # cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = ctx.executing_eagerly()
- execution_context = executing_eagerly or graph
+ with ops.init_scope():
+ init_graph = ops.get_default_graph()
+
+ # The graph, or whether we're executing eagerly, should be a part of the
+ # cache key so we don't improperly capture tensors such as variables.
+ executing_eagerly = context.executing_eagerly()
+ execution_context = executing_eagerly or init_graph
+ default_graph = ops.get_default_graph()
# Putting the device in the cache key ensures that call-site device
# annotations are respected.
- device_functions = _get_device_functions(ctx, graph)
+ device_functions = _get_device_functions(context.context(), default_graph)
# `ops.colocate_with` directives translate into `ops.device` directives when
# eager execution is enabled.
- colocation_stack = (None if executing_eagerly else
- tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ colocation_stack = (() if executing_eagerly else
+ tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
return cache_key + (execution_context, device_functions, colocation_stack)
- def _canonicalize_function_inputs(self, *args, **kwds):
- """Canonicalizes `args` and `kwds`.
+ def _canonicalize_function_inputs(self, *args, **kwargs):
+ """Canonicalizes `args` and `kwargs`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varags and kwargs that this
@@ -1054,28 +1197,28 @@ class PolymorphicFunction(object):
Args:
*args: The varargs this object was called with.
- **kwds: The keyword args this function was called with.
+ **kwargs: The keyword args this function was called with.
Returns:
A canonicalized ordering of the inputs.
Raises:
- ValueError: If a keyword in `kwds` cannot be matched with a positional
+ ValueError: If a keyword in `kwargs` cannot be matched with a positional
argument when an input signature is specified, or when the inputs
do not conform to the input signature.
"""
args = self._args_to_prepend + args
- kwds = dict(kwds, **self._kwds_to_include)
+ kwargs = dict(kwargs, **self._kwargs_to_include)
# Maps from index of arg to its corresponding value, according to `args`
- # and `kwds`; seeded with the default values for the named args that aren't
- # in `args`.
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
arg_indices_to_values = {
index: default
for index, default in six.iteritems(self._arg_indices_to_default_values)
if index >= len(args)
}
consumed_args = []
- for arg, value in six.iteritems(kwds):
+ for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
arg_indices_to_values[index] = value
@@ -1085,20 +1228,30 @@ class PolymorphicFunction(object):
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwds` will only contain true keyword arguments, as
+ # After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
- kwds.pop(arg)
+ kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ flat_inputs = nest.flatten(inputs)
+
+ # Check for NumPy arrays in arguments and convert them to Tensors.
+ need_packing = False
+ for index, value in enumerate(flat_inputs):
+ if isinstance(value, np.ndarray):
+ flat_inputs[index] = constant_op.constant(value)
+ need_packing = True
+ if need_packing:
+ inputs = nest.pack_sequence_as(structure=inputs,
+ flat_sequence=flat_inputs)
if self._input_signature is None:
- return inputs, kwds
+ return inputs, kwargs
else:
- assert not kwds
+ assert not kwargs
try:
nest.assert_same_structure(self._input_signature, inputs)
except (ValueError, TypeError):
raise ValueError("Structure of Python function inputs does not match "
"input_signature.")
- flat_inputs = nest.flatten(inputs)
if any(not isinstance(arg, ops.Tensor) for arg in flat_inputs):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
@@ -1112,25 +1265,27 @@ class PolymorphicFunction(object):
(str(inputs), str(self._input_signature)))
return inputs, {}
- def _maybe_define_function(self, *args, **kwds):
+ def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
+ `args` and `kwargs` can be None if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
Args:
- *args: args for the Python function.
- **kwds: keywords for the Python function.
+ args: The varargs for the Python function.
+ kwargs: The keyword args for the Python function.
Returns:
A graph function corresponding to the input signature implied by args and
- kwds, as well as the inputs that the object should be called with.
+ kwargs, as well as the inputs that the object should be called with.
Raises:
ValueError: If inputs are incompatible with the input signature.
TypeError: If the function inputs include non-hashable objects
"""
-
- args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds, context.context(),
- ops.get_default_graph())
+ if self._input_signature is None or args is not None or kwargs is not None:
+ args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
+ cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
graph_function = self._function_cache.get(cache_key, None)
@@ -1141,11 +1296,41 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
- self._variables.extend(
- [v for v in graph_function.variables if v not in self._variables])
+ kwargs, self._input_signature),
+ self._function_attributes)
self._function_cache[cache_key] = graph_function
- return graph_function, (args, kwds)
+ return graph_function, [
+ t for t in nest.flatten((args, kwargs))
+ if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+ ]
+
+
+def register(func, *args, **kwargs):
+ """Register the defun function into the graph.
+
+ This won't actually call the function with the inputs, and only put the
+ function definition into graph. Register function with different input param
+ will result into multiple version of functions registered in graph.
+
+ Args:
+ func: the PolymorphicFunction instance that generated by a @defun
+ *args: input arguments for the Python function.
+ **kwargs: input keyword arguments for the Python function.
+
+ Returns:
+ a `Function` object specialized to inputs and execution context.
+
+ Raises:
+ ValueError: When the input function is not a defun wrapped python function.
+ """
+ if not isinstance(func, PolymorphicFunction):
+ raise ValueError("Only defun function is allowed to be registered. "
+ "Got type: %s" % type(func))
+ concrete_func = func.get_concrete_function(*args, **kwargs)
+ graph = ops.get_default_graph()
+ concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
+ # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+ return concrete_func
def _validate_signature(signature):
@@ -1271,6 +1456,11 @@ def defun(func=None, input_signature=None):
tracing the execution of `f(*args, **kwargs)`; this graph is bound to an
input signature inferred from `(*args, **kwargs)` and cached for future reuse.
+ NumPy arrays passed as inputs to `F` are converted to `tf.Tensor` objects
+ before being passed to `f`, and are treated as Tensors for caching. This
+ allows a function to be called multiple times with NumPy arrays having
+ different values but the same shape and dtype without re-tracing each time.
+
`tf.contrib.eager.defun` caches graphs for your convenience, letting you
define TensorFlow functions without explicitly specifying their signatures.
However, this policy is conservative and potentially expensive; for example,
@@ -1470,7 +1660,29 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
+ return defun_with_attributes(func=func, input_signature=input_signature)
+
+def defun_with_attributes(func=None, input_signature=None, attributes=None):
+ """Compiles a Python function into a callable TensorFlow graph.
+
+ This function supports adding extra function attributes. See detailed
+ documentation in defun(). Currently this is not exposed in public API since we
+ don't expect user to directly use attributes, and attribute won't work by
+ itself. This assumption might change in future.
+
+ Args:
+ func: function to be compiled.
+ input_signature: same as defun()'s input_signature.
+ attributes: A dictionary of arguments which will be added to function def as
+ attributes. Currently only support primitive types as value, and only
+ whitelisted attribute name is allowed. Unwhitelisted attribute name or
+ unsupported value will result into ValueError.
+
+ Returns:
+ Same as the return value of defun, with attributes added to the function in
+ graph.
+ """
if input_signature is not None:
_validate_signature(input_signature)
@@ -1482,7 +1694,8 @@ def defun(func=None, input_signature=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature))
+ PolymorphicFunction(function, name, input_signature=input_signature,
+ attributes=attributes))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1727,9 +1940,9 @@ def automatic_control_dependencies(f):
The wrapped function.
"""
- def wrapper(*args, **kwds):
+ def wrapper(*args, **kwargs):
with AutomaticControlDependencies() as a:
- result = f(*args, **kwds)
+ result = f(*args, **kwargs)
result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
return nest.pack_sequence_as(result, result_flat)
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 92254a2c00..59faf967c5 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -21,8 +21,13 @@ import collections
import functools
from multiprocessing.pool import ThreadPool
import sys
+import weakref
+
+import numpy
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -36,12 +41,14 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@@ -55,6 +62,28 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
+class MiniModel(keras_training.Model):
+ """Minimal model for mnist.
+
+ Useful for testing and debugging on slow TPU simulators.
+ """
+
+ def __init__(self):
+ super(MiniModel, self).__init__(name='')
+ self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+ bias_initializer='ones')
+
+ def call(self, inputs, training=True):
+ return self.fc(inputs)
+
+
+class DefunnedMiniModel(MiniModel):
+
+ @function.defun
+ def call(self, inputs, training=True):
+ return super(DefunnedMiniModel, self).call(inputs, training=training)
+
+
@test_util.with_c_shapes
class FunctionTest(test.TestCase):
@@ -121,8 +150,8 @@ class FunctionTest(test.TestCase):
@function.defun
def f():
- v = resource_variable_ops.ResourceVariable(1.0)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ return self.v.read_value()
self.assertAllEqual(f(), 1.0)
@@ -314,6 +343,7 @@ class FunctionTest(test.TestCase):
def testDefunNumpyArraysConvertedToTensors(self):
def f(x):
+ self.assertIsInstance(x, ops.Tensor)
return x
x = random_ops.random_uniform([2, 2]).numpy()
@@ -327,6 +357,12 @@ class FunctionTest(test.TestCase):
# shouldn't trigger another function definition.
self.assertEqual(len(defined._function_cache), 1)
+ # Test that the numpy array is properly an argument to the graph function.
+ self.assertEqual(1., defined(numpy.ones([])).numpy())
+ self.assertEqual(0., defined(numpy.zeros([])).numpy())
+ self.assertEqual(1., defined(array_ops.ones([])).numpy())
+ self.assertEqual(0., defined(array_ops.zeros([])).numpy())
+
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -373,9 +409,9 @@ class FunctionTest(test.TestCase):
@function.defun
def tensor_init():
- v = resource_variable_ops.ResourceVariable(
+ self.v = resource_variable_ops.ResourceVariable(
lambda: constant_op.constant(2.0))
- return v.read_value()
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -389,8 +425,8 @@ class FunctionTest(test.TestCase):
def tensor_init():
with ops.init_scope():
const = constant_op.constant(2.0)
- v = resource_variable_ops.ResourceVariable(const)
- return v.read_value()
+ self.v = resource_variable_ops.ResourceVariable(const)
+ return self.v.read_value()
value = tensor_init()
if not context.executing_eagerly():
@@ -403,10 +439,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testVariableInLoopInFunction(self):
@@ -430,10 +473,17 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
+ # We do not return v directly since the tensor conversion function of
+ # ResourceVariable returns the read value and not the resource itself.
+ return v._handle
compiled = function.defun(f)
- compiled()
+ var_handle = compiled()
+ self.assertEqual(var_handle.dtype, dtypes.resource)
+ self.assertEqual(var_handle.shape, tensor_shape.scalar())
+ var_t = resource_variable_ops.read_variable_op(var_handle, dtype=v.dtype)
+ self.assertEqual(var_t.shape, tensor_shape.TensorShape([2, 2]))
def testDefunShapeInferenceWithCapturedVariableInGraphMode(self):
with context.graph_mode():
@@ -442,23 +492,46 @@ class FunctionTest(test.TestCase):
def f():
x = constant_op.constant([[1, 2], [3, 4]])
out = math_ops.matmul(v, x)
- self.assertEqual(out.get_shape(), tensor_shape.TensorShape([2, 2]))
+ self.assertEqual(out.shape, tensor_shape.TensorShape([2, 2]))
# Check that shape inference works while creating the defun
compiled = function.defun(f)
compiled()
+ def testDefunShapeInferenceWithCapturedTensorListInGraphMode(self):
+ with context.graph_mode():
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(1.0))
+ tensor_list = list_ops.tensor_list_push_back(tensor_list,
+ constant_op.constant(2.0))
+
+ def f():
+ tl, value = list_ops.tensor_list_pop_back(
+ tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+ return tl
+
+ compiled = function.defun(f)
+ output_tensor_list = compiled()
+ _, value = list_ops.tensor_list_pop_back(
+ output_tensor_list, element_dtype=dtypes.float32)
+ self.assertEqual(value.shape, tensor_shape.scalar())
+
@test_util.run_in_graph_and_eager_modes
def testDefunForcesResourceVariables(self):
def variable_creator():
- return variables.Variable(0.0).read_value()
+ self.v = variables.Variable(0.0)
+ return self.v.read_value()
+ self.v = None
defined = function.defun(variable_creator)
defined() # Create the variable.
- self.assertEqual(len(defined.variables), 1)
self.assertIsInstance(
- defined.variables[0], resource_variable_ops.ResourceVariable)
+ self.v, resource_variable_ops.ResourceVariable)
def testDefunDifferentiable(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -996,6 +1069,7 @@ class FunctionTest(test.TestCase):
with ops.get_default_graph().as_default():
create_variable()
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self):
conv = convolutional.Conv2D(
filters=1,
@@ -1009,7 +1083,34 @@ class FunctionTest(test.TestCase):
x = array_ops.ones([1, 2, 2, 1])
y = model(x)
- self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+ # Remove reference cycles in model
+ test_util.dismantle_polymorphic_function(model)
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testDefunKerasModelCall(self):
+ model = MiniModel()
+ model.call = function.defun(model.call)
+
+ x = array_ops.ones([1, 2])
+ y = model(x)
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+
+ self.assertAllEqual([[3.0]], self.evaluate(y))
+
+ # Remove reference cycles in defun.
+ test_util.dismantle_polymorphic_function(model.call)
+ # Break the reference cycle between the MiniModel and the defun:
+ # MiniModel --(through its `call` method)--> PolymorphicFunction
+ # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+ del model.call
# Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`.
@@ -1130,13 +1231,11 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo)
x = constant_op.constant([1.0])
- self.assertAllEqual(defined.variables, [])
- _ = defined(x)
- self.assertAllEqual(defined.variables, [v])
+ self.assertEqual(1., self.evaluate(defined(x)))
+ v.assign(2.)
x = constant_op.constant([1.0, 2.0])
- _ = defined(x) # ensure the variables list remains the same
- self.assertAllEqual(defined.variables, [v])
+ self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
def testPythonFunctionWithDefaultArgs(self):
@@ -1492,6 +1591,257 @@ class FunctionTest(test.TestCase):
side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
+ def testFunctionWithExtraAttributes(self):
+ @function.defun_with_attributes(attributes={'experimental_1': 'value1',
+ 'experimental_2': 2})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ def add(x, y):
+ return math_ops.add(x, y)
+ defun_add = function.defun_with_attributes(
+ add, attributes={'experimental_3': True, 'experimental_4': 1.0})
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ sq = matmul(t, t)
+ double = defun_add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ self.assertRegexpMatches(
+ functions[0].definition.signature.name, '.*matmul.*')
+ attrs = functions[0].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_1'].s, b'value1')
+ self.assertEqual(attrs['experimental_2'].i, 2)
+
+ self.assertRegexpMatches(
+ functions[1].definition.signature.name, '.*add.*')
+ attrs = functions[1].definition.attr
+ self.assertEqual(len(attrs), 2)
+ self.assertEqual(attrs['experimental_3'].b, True)
+ self.assertEqual(attrs['experimental_4'].f, 1.0)
+ # pylint: enable=protected-access
+
+ def testFunctionWithInvalidAttribute(self):
+ @function.defun_with_attributes(attributes={'attr1': 'value1'})
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Attribute name is not whitelisted.*'):
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ matmul(t, t)
+
+ @function.defun_with_attributes(attributes={'experimental_1': ['value1']})
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ with self.assertRaisesRegexp(ValueError,
+ '.*Unsupported attribute type.*'):
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ add(t, t)
+
+ def testRegisterFunction(self):
+ @function.defun
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+ function.register(add, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ pre_register_matmul_func_name = functions[0].definition.signature.name
+ self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
+ pre_register_add_func_name = functions[1].definition.signature.name
+ self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+
+ sq = defun_matmul(t, t)
+ double = add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+ # Make sure the pre registered function is used, and no other function
+ # is added.
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ called_func_name = functions[0].definition.signature.name
+ self.assertEqual(pre_register_matmul_func_name, called_func_name)
+ called_func_name = functions[1].definition.signature.name
+ self.assertEqual(pre_register_add_func_name, called_func_name)
+
+ def testRegisterFunctionWithInputSignature(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(
+ matmul,
+ input_signature=[
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+ ])
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ # Test input param shape mismatch
+ t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ with self.assertRaisesRegexp(
+ ValueError, 'Python inputs incompatible with input_signature'):
+ function.register(defun_matmul, t2, t2)
+
+ def testRegisterFunctionWithCache(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
+ function.register(defun_matmul, t, t)
+ function.register(defun_matmul, t2, t2)
+
+ graph = ops.get_default_graph()
+ # Only one function is registered since the input param are in same type
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ def testCallingFunctionWithDifferentVariables(self):
+
+ @function.defun
+ def foo(v):
+ v.assign_add(1.0)
+ return v.read_value()
+
+ v = resource_variable_ops.ResourceVariable(0.0)
+ graph_function = foo.get_concrete_function(v)
+ self.assertEqual(len(graph_function.inputs), 1)
+ self.assertEqual(len(graph_function.captured_inputs), 0)
+
+ self.assertEqual(float(graph_function(v)), 1.0)
+ self.assertEqual(float(graph_function(v)), 2.0)
+
+ w = resource_variable_ops.ResourceVariable(0.0)
+
+ @function.defun
+ def bar(v):
+ del v
+ return constant_op.constant(1.0)
+
+ graph_function = bar.get_concrete_function(v)
+ self.assertEqual(float(graph_function(v)), 1.0)
+ self.assertEqual(float(graph_function(w)), 1.0)
+
+ def testCallingFunctionWithNonTensorsFails(self):
+
+ @function.defun
+ def foo(x):
+ return x
+
+ graph_function = foo.get_concrete_function(constant_op.constant(1.0))
+ with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must '
+ 'be Tensors;.*'):
+ graph_function('Not a Tensor.')
+
+ def testSwapImplementationWithGrapplerPlugin(self):
+ rewrites = rewriter_config_pb2.RewriterConfig()
+ # function_optimizer has to be turn off, otherwise it will delete the
+ # registered function if it does not get called.
+ # TODO(scottzhu): Move the ExperimentalImplementationSelector to be called
+ # before function_optimizer in future.
+ rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ customer_optimizer = rewrites.custom_optimizers.add()
+ customer_optimizer.name = 'ExperimentalImplementationSelector'
+ rewrites.min_graph_nodes = -1
+ graph_options = config_pb2.GraphOptions(
+ rewrite_options=rewrites, build_cost_model=1)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+
+ with context.graph_mode(), self.cached_session(
+ config=config, graph=ops.Graph(), use_gpu=True) as sess:
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'CPU'
+ })
+ def cpu_boost(x):
+ return math_ops.add(x, 2.0)
+
+ @function.defun_with_attributes(
+ attributes={
+ 'experimental_api_implements': 'random_boost',
+ 'experimental_api_preferred_device': 'GPU'
+ })
+ def gpu_boost(x):
+ return math_ops.add(x, 4.0)
+
+ x = constant_op.constant(1.0)
+
+ function.register(cpu_boost, x)
+ y = gpu_boost(x)
+ y_value = sess.run(y)
+
+ if test.is_gpu_available():
+ self.assertEquals(y_value, 5.0)
+ else:
+ # Grappler fallback to use the CPU impl even called with GPU function.
+ self.assertEquals(y_value, 3.0)
+
+ def testDefunFunctionSeparateGraphs(self):
+ with context.graph_mode():
+
+ @function.defun
+ def add(x):
+ return x + 5
+
+ @function.defun
+ def maybe_add(x, should_add):
+ if should_add:
+ return add(x)
+ else:
+ return x
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 1)
+ self.assertEqual(len(add._function_cache), 1)
+
+ maybe_add(x, False)
+ self.assertEqual(len(maybe_add._function_cache), 2)
+ self.assertEqual(len(add._function_cache), 1)
+
+ with ops.Graph().as_default():
+ x = constant_op.constant(11)
+ maybe_add(x, True)
+ self.assertEqual(len(maybe_add._function_cache), 3)
+ self.assertEqual(len(add._function_cache), 2)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):
@@ -1683,10 +2033,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
value = train()
self.assertEqual(value.numpy(), -1.0)
@@ -1713,10 +2063,10 @@ class AutomaticControlDependenciesTest(test.TestCase):
@function.defun
def train():
- v = resource_variable_ops.ResourceVariable(1.0)
- grad = backprop.implicit_grad(loss)(v)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
+ grad = backprop.implicit_grad(loss)(self.v)
optimizer.apply_gradients(grad)
- return v.read_value()
+ return self.v.read_value()
train()
@@ -1903,6 +2253,13 @@ class AutomaticControlDependenciesTest(test.TestCase):
modify_same_flat(nested_input)
+ def testDecoratedMethodVariableCleanup(self):
+ m = DefunnedMiniModel()
+ m(array_ops.ones([1, 2]))
+ weak_variables = weakref.WeakSet(m.variables)
+ self.assertEqual(2, len(weak_variables))
+ del m
+ self.assertEqual([], list(weak_variables))
if __name__ == '__main__':
ops.enable_eager_execution(
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f027d107c..5f5af4ab6c 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -23,8 +23,9 @@ import collections
from tensorflow.python import pywrap_tensorflow
-VSpace = collections.namedtuple(
- "VSpace", ["aggregate_fn", "num_elements_fn", "zeros", "ones"])
+VSpace = collections.namedtuple("VSpace", [
+ "aggregate_fn", "num_elements_fn", "zeros_fn", "ones_fn", "graph_shape_fn"
+])
def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index f34ce6af79..5f44bd4fec 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -516,25 +516,13 @@ static PyObject* EagerTensor_rank(EagerTensor* self) {
// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
auto handle = self->handle;
- int n = TFE_TensorHandleNumDims(handle, self->status);
+ int n = TFE_TensorHandleNumElements(handle, self->status);
if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
// Cleanup self->status before returning.
TF_SetStatus(self->status, TF_OK, "");
return nullptr;
}
- tensorflow::int64 value = 1;
- if (PyErr_Occurred()) return nullptr;
- for (int i = 0; i < n; ++i) {
- int64_t dim = TFE_TensorHandleDim(handle, i, self->status);
- if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
- // Cleanup self->status before returning.
- TF_SetStatus(self->status, TF_OK, "");
- PyErr_SetString(PyExc_RuntimeError, "Error while iterating dimensions");
- return nullptr;
- }
- value *= dim;
- }
- return PyLong_FromLongLong(value);
+ return PyLong_FromLongLong(n);
}
static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
@@ -777,17 +765,34 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
return reinterpret_cast<PyObject*>(t);
}
-tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return reinterpret_cast<const EagerTensor*>(tensor)->id;
}
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor) {
- CHECK(EagerTensor_CheckExact(tensor));
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
reinterpret_cast<const EagerTensor*>(tensor)->handle));
}
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
+ DCHECK(EagerTensor_CheckExact(tensor));
+ const EagerTensor* as_c_eager_tensor =
+ reinterpret_cast<const EagerTensor*>(tensor);
+ tensorflow::int64 result = TFE_TensorHandleNumElements(
+ as_c_eager_tensor->handle, as_c_eager_tensor->status);
+
+ if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
+ PyExc_ValueError)) {
+ // Cleanup status before returning.
+ TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
+ return -1;
+ }
+
+ return result;
+}
+
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
if (!PyType_Check(base_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
index bc042eb19e..4eaa1ba536 100644
--- a/tensorflow/python/eager/pywrap_tensor.h
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -21,8 +21,9 @@ limitations under the License.
#include "tensorflow/python/lib/core/numpy.h"
bool EagerTensor_CheckExact(const PyObject* o);
-tensorflow::int64 EagerTensor_id(const PyObject* tensor);
-tensorflow::DataType EagerTensor_dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor);
+tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor);
+tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor);
namespace tensorflow {
TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype);
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 46dcf7c8a8..159b1c1218 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -860,7 +861,7 @@ static tensorflow::int64 MakeInt(PyObject* integer) {
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_id(tensor);
+ return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
@@ -873,7 +874,7 @@ static tensorflow::int64 FastTensorId(PyObject* tensor) {
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
- return EagerTensor_dtype(tensor);
+ return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
@@ -889,12 +890,239 @@ static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
return static_cast<tensorflow::DataType>(id);
}
+class PyTapeTensor {
+ public:
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ const tensorflow::TensorShape& shape)
+ : id_(id), dtype_(dtype), shape_(shape) {}
+ PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
+ PyObject* shape)
+ : id_(id), dtype_(dtype), shape_(shape) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ PyTapeTensor(const PyTapeTensor& other) {
+ id_ = other.id_;
+ dtype_ = other.dtype_;
+ shape_ = other.shape_;
+ if (shape_.index() == 1) {
+ Py_INCREF(absl::get<1>(shape_));
+ }
+ }
+
+ ~PyTapeTensor() {
+ if (shape_.index() == 1) {
+ Py_DECREF(absl::get<1>(shape_));
+ }
+ }
+ PyObject* GetShape() const;
+ PyObject* GetDType() const { return PyLong_FromLong(dtype_); }
+ tensorflow::int64 GetID() const { return id_; }
+
+ private:
+ tensorflow::int64 id_;
+ tensorflow::DataType dtype_;
+ absl::variant<tensorflow::TensorShape, PyObject*> shape_;
+};
+
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
+ public:
+ explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
+ Py_INCREF(py_vspace_);
+ }
+
+ tensorflow::Status Initialize() {
+ num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+ if (num_elements_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+ if (aggregate_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
+ if (zeros_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
+ if (ones_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
+ if (graph_shape_fn_ == nullptr) {
+ return tensorflow::errors::InvalidArgument("invalid vspace");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ ~PyVSpace() override {
+ Py_XDECREF(num_elements_);
+ Py_XDECREF(aggregate_fn_);
+ Py_XDECREF(zeros_fn_);
+ Py_XDECREF(ones_fn_);
+ Py_XDECREF(graph_shape_fn_);
+
+ Py_DECREF(py_vspace_);
+ }
+
+ tensorflow::int64 NumElements(PyObject* tensor) const final {
+ if (EagerTensor_CheckExact(tensor)) {
+ return PyEagerTensor_NumElements(tensor);
+ }
+ PyObject* arglist =
+ Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+ PyObject* result = PyEval_CallObject(num_elements_, arglist);
+ Py_DECREF(arglist);
+ if (result == nullptr) {
+ // The caller detects whether a python exception has been raised.
+ return -1;
+ }
+ tensorflow::int64 r = MakeInt(result);
+ Py_DECREF(result);
+ return r;
+ }
+
+ PyObject* AggregateGradients(
+ tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
+ PyObject* list = PyList_New(gradient_tensors.size());
+ for (int i = 0; i < gradient_tensors.size(); ++i) {
+ // Note: stealing a reference to the gradient tensors.
+ CHECK(gradient_tensors[i] != nullptr);
+ CHECK(gradient_tensors[i] != Py_None);
+ PyList_SET_ITEM(list, i,
+ reinterpret_cast<PyObject*>(gradient_tensors[i]));
+ }
+ PyObject* arglist = Py_BuildValue("(O)", list);
+ CHECK(arglist != nullptr);
+ PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+ Py_DECREF(arglist);
+ Py_DECREF(list);
+ return result;
+ }
+
+ void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
+
+ PyObject* Zeros(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return reinterpret_cast<PyObject*>(result);
+ }
+
+ PyObject* Ones(const PyTapeTensor& tensor) const final {
+ PyObject* py_shape = tensor.GetShape();
+ PyObject* py_dtype = tensor.GetDType();
+ PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+ PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
+ Py_DECREF(arg_list);
+ Py_DECREF(py_dtype);
+ Py_DECREF(py_shape);
+ return result;
+ }
+
+ PyObject* GraphShape(PyObject* tensor) const {
+ PyObject* arg_list = Py_BuildValue("(O)", tensor);
+ PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
+ Py_DECREF(arg_list);
+ return result;
+ }
+
+ tensorflow::Status CallBackwardFunction(
+ PyBackwardFunction* backward_function,
+ tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+ std::vector<PyObject*>* result) const final {
+ PyObject* grads = PyTuple_New(output_gradients.size());
+ for (int i = 0; i < output_gradients.size(); ++i) {
+ if (output_gradients[i] == nullptr) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(grads, i, Py_None);
+ } else {
+ PyTuple_SET_ITEM(grads, i,
+ reinterpret_cast<PyObject*>(output_gradients[i]));
+ }
+ }
+ PyObject* py_result = (*backward_function)(grads);
+ Py_DECREF(grads);
+ if (py_result == nullptr) {
+ return tensorflow::errors::Internal("gradient function threw exceptions");
+ }
+ result->clear();
+ PyObject* seq =
+ PySequence_Fast(py_result, "expected a sequence of gradients");
+ if (seq == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "gradient function did not return a list");
+ }
+ int len = PySequence_Fast_GET_SIZE(seq);
+ VLOG(1) << "Gradient length is " << len;
+ result->reserve(len);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+ if (item == Py_None) {
+ result->push_back(nullptr);
+ } else {
+ Py_INCREF(item);
+ result->push_back(item);
+ }
+ }
+ Py_DECREF(seq);
+ Py_DECREF(py_result);
+ return tensorflow::Status::OK();
+ }
+
+ void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
+
+ private:
+ PyObject* py_vspace_;
+
+ PyObject* num_elements_;
+ PyObject* aggregate_fn_;
+ PyObject* zeros_fn_;
+ PyObject* ones_fn_;
+ PyObject* graph_shape_fn_;
+};
+PyVSpace* py_vspace = nullptr;
+
+PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
+ if (py_vspace != nullptr) {
+ delete py_vspace;
+ }
+
+ py_vspace = new PyVSpace(e);
+ auto status = py_vspace->Initialize();
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ delete py_vspace;
+ return nullptr;
+ }
+
+ Py_RETURN_NONE;
+}
+
+PyObject* PyTapeTensor::GetShape() const {
+ if (shape_.index() == 0) {
+ auto& shape = absl::get<0>(shape_);
+ PyObject* py_shape = PyTuple_New(shape.dims());
+ for (int i = 0; i < shape.dims(); ++i) {
+ PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+ }
+
+ return py_shape;
+ }
+
+ return py_vspace->GraphShape(absl::get<1>(shape_));
+}
+
class GradientTape
- : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction> {
+ : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
- : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction>(
- persistent),
+ : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
+ PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {
@@ -1175,24 +1403,41 @@ void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
-static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+bool ListContainsNone(PyObject* list) {
+ if (list == Py_None) return true;
+ tensorflow::Safe_PyObjectPtr seq(
+ PySequence_Fast(list, "expected a sequence"));
+ if (seq == nullptr) {
+ return false;
+ }
+
+ int len = PySequence_Size(list);
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(seq.get(), i);
+ if (item == Py_None) return true;
+ }
+
+ return false;
+}
+
+static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
- tensorflow::int64 id = EagerTensor_id(tensor);
+ tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::TensorShape tensor_shape;
const tensorflow::Status status = t->handle->Shape(&tensor_shape);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype,
- tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
} else {
- return tensorflow::eager::TapeTensor{id, t->handle->dtype, tensor_shape};
+ return PyTapeTensor(id, t->handle->dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
- return tensorflow::eager::TapeTensor{
- id, static_cast<tensorflow::DataType>(0), tensorflow::TensorShape({})};
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
@@ -1200,16 +1445,21 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
PyObject* shape_tuple = PyObject_CallMethod(tensor, _shape_tuple, nullptr);
- if (PyErr_Occurred() != nullptr) {
- return tensorflow::eager::TapeTensor{id, dtype,
- tensorflow::TensorShape({})};
+ if (PyErr_Occurred()) {
+ return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
+ tensorflow::TensorShape({}));
}
+
+ if (ListContainsNone(shape_tuple)) {
+ return PyTapeTensor(id, dtype, tensor);
+ }
+
auto l = MakeIntList(shape_tuple);
Py_DECREF(shape_tuple);
// Replace -1, which represents accidental Nones which can occur in graph mode
@@ -1220,7 +1470,7 @@ static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
}
}
tensorflow::TensorShape shape(l);
- return tensorflow::eager::TapeTensor{id, dtype, shape};
+ return PyTapeTensor(id, dtype, shape);
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
@@ -1286,7 +1536,7 @@ void TapeSetRecordOperation(
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer) {
- std::vector<tensorflow::eager::TapeTensor> output_info;
+ std::vector<PyTapeTensor> output_info;
PyObject* seq = PySequence_Fast(output_tensors,
"expected a sequence of integer tensor ids");
int len = PySequence_Size(output_tensors);
@@ -1362,173 +1612,6 @@ void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
}
}
-class PyVSpace
- : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction> {
- public:
- explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
- Py_INCREF(py_vspace_);
- }
-
- tensorflow::Status Initialize() {
- num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
- if (num_elements_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
- if (aggregate_fn_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
- if (zeros_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- ones_ =
- PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
- if (ones_ == nullptr) {
- return tensorflow::errors::InvalidArgument("invalid vspace");
- }
- return tensorflow::Status::OK();
- }
-
- ~PyVSpace() override {
- Py_XDECREF(num_elements_);
- Py_XDECREF(aggregate_fn_);
- Py_XDECREF(zeros_);
- Py_XDECREF(ones_);
-
- Py_DECREF(py_vspace_);
- }
-
- tensorflow::int64 NumElements(PyObject* tensor) const final {
- PyObject* arglist =
- Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
- PyObject* result = PyEval_CallObject(num_elements_, arglist);
- tensorflow::int64 r = MakeInt(result);
- Py_DECREF(result);
- Py_DECREF(arglist);
- return r;
- }
-
- PyObject* AggregateGradients(
- tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
- PyObject* list = PyList_New(gradient_tensors.size());
- for (int i = 0; i < gradient_tensors.size(); ++i) {
- // Note: stealing a reference to the gradient tensors.
- CHECK(gradient_tensors[i] != nullptr);
- CHECK(gradient_tensors[i] != Py_None);
- PyList_SET_ITEM(list, i,
- reinterpret_cast<PyObject*>(gradient_tensors[i]));
- }
- PyObject* arglist = Py_BuildValue("(O)", list);
- CHECK(arglist != nullptr);
- PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
- Py_DECREF(arglist);
- Py_DECREF(list);
- return result;
- }
-
- void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
-
- PyObject* Zeros(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(zeros_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return reinterpret_cast<PyObject*>(result);
- }
-
- PyObject* Ones(tensorflow::TensorShape shape,
- tensorflow::DataType dtype) const final {
- PyObject* py_shape = PyTuple_New(shape.dims());
- for (int i = 0; i < shape.dims(); ++i) {
- PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
- }
- PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
- PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
- PyObject* result = PyEval_CallObject(ones_, arg_list);
- Py_DECREF(arg_list);
- Py_DECREF(py_dtype);
- Py_DECREF(py_shape);
- return result;
- }
-
- tensorflow::Status CallBackwardFunction(
- PyBackwardFunction* backward_function,
- tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
- std::vector<PyObject*>* result) const final {
- PyObject* grads = PyTuple_New(output_gradients.size());
- for (int i = 0; i < output_gradients.size(); ++i) {
- if (output_gradients[i] == nullptr) {
- Py_INCREF(Py_None);
- PyTuple_SET_ITEM(grads, i, Py_None);
- } else {
- PyTuple_SET_ITEM(grads, i,
- reinterpret_cast<PyObject*>(output_gradients[i]));
- }
- }
- PyObject* py_result = (*backward_function)(grads);
- Py_DECREF(grads);
- if (py_result == nullptr) {
- return tensorflow::errors::Internal("gradient function threw exceptions");
- }
- result->clear();
- PyObject* seq =
- PySequence_Fast(py_result, "expected a sequence of gradients");
- if (seq == nullptr) {
- return tensorflow::errors::InvalidArgument(
- "gradient function did not return a list");
- }
- int len = PySequence_Fast_GET_SIZE(seq);
- VLOG(1) << "Gradient length is " << len;
- result->reserve(len);
- for (int i = 0; i < len; ++i) {
- PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
- if (item == Py_None) {
- result->push_back(nullptr);
- } else {
- Py_INCREF(item);
- result->push_back(item);
- }
- }
- Py_DECREF(seq);
- Py_DECREF(py_result);
- return tensorflow::Status::OK();
- }
-
- void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
-
- private:
- PyObject* py_vspace_;
-
- PyObject* num_elements_;
- PyObject* aggregate_fn_;
- PyObject* zeros_;
- PyObject* ones_;
-};
-PyVSpace* py_vspace = nullptr;
-
-PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
- if (py_vspace != nullptr) {
- delete py_vspace;
- }
-
- py_vspace = new PyVSpace(e);
- auto status = py_vspace->Initialize();
- if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
- delete py_vspace;
- return nullptr;
- }
-
- Py_RETURN_NONE;
-}
-
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
@@ -1740,6 +1823,9 @@ PyObject* MaybeGetDTypeForAttr(const string& attr,
Py_RETURN_NONE;
}
+// TODO(agarwal): use an automatic mechanism for handling None arguments to
+// gradient functions.
+
// Returns a pair where the first value of the pair indicates whether or not all
// outputs are unused. If the first value is false, the second value is a
// set that identifies which of the output indices are unused.
@@ -2565,13 +2651,18 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) {
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
+ int delta = 1;
if (!output_arg.number_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.number_attr()];
+ delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
- num_retvals += attr_list_sizes[output_arg.type_list_attr()];
- } else {
- num_retvals++;
+ delta = attr_list_sizes[output_arg.type_list_attr()];
+ }
+ if (delta < 0) {
+ RaiseFallbackException(
+ "Attributes suggest that the size of an output list is less than 0");
+ return nullptr;
}
+ num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
diff --git a/tensorflow/python/eager/pywrap_tfe_test.py b/tensorflow/python/eager/pywrap_tfe_test.py
index fd8ab695b8..669fa08488 100644
--- a/tensorflow/python/eager/pywrap_tfe_test.py
+++ b/tensorflow/python/eager/pywrap_tfe_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import core
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -123,8 +124,8 @@ class Tests(test.TestCase):
def testFastpathExecute_MixedPrecisionVariableTapeWrite(self):
ctx = context.context()
with backprop.GradientTape(persistent=True) as tape:
- a_2_by_2 = constant_op.constant(
- [[1.0, 2.0], [3.0, 4.0]], dtype=dtypes.float32)
+ a_2_by_2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]],
+ dtype=dtypes.float32)
a_2_by_2_fp16 = math_ops.cast(a_2_by_2, dtype=dtypes.float16)
m1 = resource_variable_ops.ResourceVariable(a_2_by_2)
m2 = resource_variable_ops._MixedPrecisionVariable(
@@ -233,6 +234,26 @@ class Tests(test.TestCase):
pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
ctx_handle, None, [], a_2_by_2)
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testFastPathExecute_InvalidAttributes(self):
+ split_dim = constant_op.constant(0, dtype=dtypes.int32)
+ value = constant_op.constant([0, 1, 2, 3], dtype=dtypes.float32)
+ ctx = context.context()
+ ctx_handle = ctx._handle
+ with self.assertRaises(core._FallbackException):
+ pywrap_tensorflow.TFE_Py_FastPathExecute(ctx_handle, ctx.device_name,
+ "Split", None, None, split_dim,
+ value, "num_split", -1)
+
+ @test_util.assert_no_new_tensors
+ @test_util.assert_no_garbage_created
+ def testInvalidNumOutputs(self):
+ with self.assertRaisesRegexp(
+ Exception,
+ "Value for attr 'num_split' of -1 must be at least minimum 1"):
+ array_ops.split(value=[1, 2, 3], num_or_size_splits=-1)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 4001ffdd6b..7f2349954d 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -197,6 +197,7 @@ py_library(
srcs = ["canned/boosted_trees.py"],
srcs_version = "PY2AND3",
deps = [
+ ":boosted_trees_utils",
":estimator",
":head",
":model_fn",
@@ -224,6 +225,35 @@ py_test(
)
py_library(
+ name = "boosted_trees_utils",
+ srcs = ["canned/boosted_trees_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ "//tensorflow:tensorflow_py_no_contrib",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_utils_test",
+ size = "medium",
+ srcs = ["canned/boosted_trees_utils_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ ":boosted_trees",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
@@ -685,6 +715,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
+ "notsan", # b/67510291
],
deps = [
":keras",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 19f18015e4..0278990cfc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,8 +21,12 @@ import abc
import collections
import functools
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import boosted_trees_utils
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import dtypes
@@ -36,8 +40,10 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -191,14 +197,50 @@ def _calculate_num_features(sorted_feature_columns):
return num_features
+def _generate_feature_name_mapping(sorted_feature_columns):
+ """Return a list of feature name for feature ids.
+
+ Args:
+ sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+ Returns:
+ feature_name_mapping: a list of feature names indexed by the feature ids.
+
+ Raises:
+ ValueError: when unsupported features/columns are tried.
+ """
+ names = []
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ categorical_column = column.categorical_column
+ if isinstance(categorical_column,
+ feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access
+ for value in categorical_column.vocabulary_list:
+ names.append('{}:{}'.format(column.name, value))
+ elif isinstance(categorical_column,
+ feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
+ boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
+ for pair in zip(boundaries[:-1], boundaries[1:]):
+ names.append('{}:{}'.format(column.name, pair))
+ else:
+ for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
+ names.append('{}:{}'.format(column.name, num))
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
+ names.append(column.name)
+ else:
+ raise ValueError(
+ 'For now, only bucketized_column and indicator_column is supported '
+ 'but got: {}'.format(column))
+ return names
+
+
def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
num_features = _calculate_num_features(sorted_feature_columns)
cached_features = [
_local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
- name='cached_feature_{}'.format(i))
- for i in range(num_features)
+ name='cached_feature_{}'.format(i)) for i in range(num_features)
]
are_features_cached = _local_variable(False, name='are_features_cached')
@@ -228,8 +270,7 @@ def _cache_transformed_features(features, sorted_feature_columns, batch_size):
return cached, cache_flip_op
input_feature_list, cache_flip_op = control_flow_ops.cond(
- are_features_cached,
- lambda: (cached_features, control_flow_ops.no_op()),
+ are_features_cached, lambda: (cached_features, control_flow_ops.no_op()),
cache_features_and_return)
return input_feature_list, cache_flip_op
@@ -263,8 +304,8 @@ class _CacheTrainingStatesUsingHashTable(object):
elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
empty_key = ''
else:
- raise ValueError('Unsupported example_id_feature dtype %s.' %
- example_ids.dtype)
+ raise ValueError(
+ 'Unsupported example_id_feature dtype %s.' % example_ids.dtype)
# Cache holds latest <tree_id, node_id, logits> for each example.
# tree_id and node_id are both int32 but logits is a float32.
# To reduce the overhead, we store all of them together as float32 and
@@ -273,8 +314,8 @@ class _CacheTrainingStatesUsingHashTable(object):
empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
self._example_ids = ops.convert_to_tensor(example_ids)
if self._example_ids.shape.ndims not in (None, 1):
- raise ValueError('example_id should have rank 1, but got %s' %
- self._example_ids)
+ raise ValueError(
+ 'example_id should have rank 1, but got %s' % self._example_ids)
self._logits_dimension = logits_dimension
def lookup(self):
@@ -334,7 +375,7 @@ class _CacheTrainingStatesUsingVariables(object):
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -422,9 +463,13 @@ class _EnsembleGrower(object):
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
- if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
- and tree_hparams.tree_complexity <= 0):
- raise ValueError('For pruning, tree_complexity must be positive.')
+ if tree_hparams.tree_complexity > 0:
+ if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError(
+ 'Tree complexity have no effect unless pruning mode is chosen.')
+ else:
+ if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError('For pruning, tree_complexity must be positive.')
# pylint: enable=protected-access
@abc.abstractmethod
@@ -719,7 +764,7 @@ def _bt_model_fn(
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
+ if mode != model_fn_lib.ModeKeys.TRAIN:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
logits = boosted_trees_ops.predict(
@@ -886,6 +931,7 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+
# Add an early stop hook.
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
@@ -927,8 +973,8 @@ def _create_classification_head_and_closed_form(n_classes, weight_column,
label_vocabulary):
"""Creates a head for classifier and the closed form gradients/hessians."""
head = _create_classification_head(n_classes, weight_column, label_vocabulary)
- if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None
- and label_vocabulary is None):
+ if (n_classes == 2 and head.logits_dimension == 1 and
+ weight_column is None and label_vocabulary is None):
# Use the closed-form gradients/hessians for 2 class.
def _grad_and_hess_for_logloss(logits, labels):
"""A closed form gradient and hessian for logistic loss."""
@@ -961,8 +1007,282 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
+def _compute_feature_importances_per_tree(tree, num_features):
+ """Computes the importance of each feature in the tree."""
+ importances = np.zeros(num_features)
+
+ for node in tree.nodes:
+ node_type = node.WhichOneof('node')
+ if node_type == 'bucketized_split':
+ feature_id = node.bucketized_split.feature_id
+ importances[feature_id] += node.metadata.gain
+ elif node_type == 'leaf':
+ assert node.metadata.gain == 0
+ else:
+ raise ValueError('Unexpected split type %s', node_type)
+
+ return importances
+
+
+def _compute_feature_importances(tree_ensemble, num_features, normalize):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the feature.
+
+ Args:
+ tree_ensemble: a trained tree ensemble, instance of proto
+ boosted_trees.TreeEnsemble.
+ num_features: The total number of feature ids.
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_idx: A list of feature_id which is sorted
+ by its feature importance.
+ feature_importances: A list of corresponding feature importances.
+
+ Raises:
+ AssertionError: When normalize = True, if feature importances
+ contain negative value, or if normalization is not possible
+ (e.g. ensemble is empty or trees contain only a root node).
+ """
+ tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
+ for tree in tree_ensemble.trees]
+ tree_importances = np.array(tree_importances)
+ tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
+ feature_importances = np.sum(tree_importances * tree_weights, axis=0)
+ if normalize:
+ assert np.all(feature_importances >= 0), ('feature_importances '
+ 'must be non-negative.')
+ normalizer = np.sum(feature_importances)
+ assert normalizer > 0, 'Trees are all empty or contain only a root node.'
+ feature_importances /= normalizer
+
+ sorted_feature_idx = np.argsort(feature_importances)[::-1]
+ return sorted_feature_idx, feature_importances[sorted_feature_idx]
+
+
+def _bt_explanations_fn(features,
+ head,
+ sorted_feature_columns,
+ name='boosted_trees'):
+ """Gradient Boosted Trees predict with explanations model_fn.
+
+ Args:
+ features: dict of `Tensor`.
+ head: A `head_lib._Head` instance.
+ sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn`
+ model inputs.
+ name: Name used for the model.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: mode or params are invalid, or features has the wrong type.
+ """
+ mode = model_fn_lib.ModeKeys.PREDICT
+ with ops.name_scope(name) as name:
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=None,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op # pylint: disable=protected-access
+ return estimator_spec
+
+
+class _BoostedTreesBase(estimator.Estimator):
+ """Base class for boosted trees estimators.
+
+ This class is intended to keep tree-specific functions (E.g., methods for
+ feature importances and directional feature contributions) in one central
+ place.
+
+ It is not a valid (working) Estimator on its own and should only be used as a
+ base class.
+ """
+
+ def __init__(self, model_fn, model_dir, config, feature_columns, head,
+ center_bias, is_classification):
+ """Initializes a `_BoostedTreesBase` instance.
+
+ Args:
+ model_fn: model_fn: Model function. See base class for more detail.
+ model_dir: Directory to save model parameters, graph and etc. See base
+ class for more detail.
+ config: `estimator.RunConfig` configuration object.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+ from `FeatureColumn`
+ head: A `head_lib._Head` instance.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+ is_classification: If the estimator is for classification.
+ """
+ super(_BoostedTreesBase, self).__init__(
+ model_fn=model_fn, model_dir=model_dir, config=config)
+ self._sorted_feature_columns = sorted(
+ feature_columns, key=lambda tc: tc.name)
+ self._head = head
+ self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._names_for_feature_id = np.array(
+ _generate_feature_name_mapping(self._sorted_feature_columns))
+ self._center_bias = center_bias
+ self._is_classification = is_classification
+
+ def experimental_feature_importances(self, normalize=False):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the corresponding feature.
+
+ Args:
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+ sorted_feature_names: 1-D array of feature name which is sorted
+ by its feature importance.
+ feature_importances: 1-D array of the corresponding feature importance.
+
+ Raises:
+ ValueError: When attempting to normalize on an empty ensemble
+ or an ensemble of trees which have no splits. Or when attempting
+ to normalize and feature importances have negative values.
+ """
+ reader = checkpoint_utils.load_checkpoint(self._model_dir)
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ if not serialized:
+ raise ValueError('Found empty serialized string for TreeEnsemble.'
+ 'You should only call this method after training.')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ sorted_feature_id, importances = _compute_feature_importances(
+ ensemble_proto, self._n_features, normalize)
+ return self._names_for_feature_id[sorted_feature_id], importances
+
+ def experimental_predict_with_explanations(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None):
+ """Computes model explainability outputs per example along with predictions.
+
+ Currently supports directional feature contributions (DFCs). For each
+ instance, DFCs indicate the aggregate contribution of each feature. See
+ https://arxiv.org/abs/1312.1121 and
+ http://blog.datadive.net/interpreting-random-forests/ for more details.
+ Args:
+ input_fn: A function that provides input data for predicting as
+ minibatches. See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A `tf.data.Dataset` object: Outputs of `Dataset`
+ object must be a tuple `(features, labels)` with same constraints as
+ below. * A tuple `(features, labels)`: Where `features` is a `tf.Tensor`
+ or a dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+ predict_keys: list of `str`, name of the keys to predict. It is used if
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary, with the exception of 'bias' and 'dfc', which will
+ always be in the dictionary. If `None`, returns all keys in prediction
+ dict, as well as two new keys 'dfc' and 'bias'.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
+ checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, prediction is run with newly initialized `Variables`
+ instead of ones restored from checkpoint.
+
+ Yields:
+ Evaluated values of `predictions` tensors. The `predictions` tensors will
+ contain at least two keys 'dfc' and 'bias' for model explanations. The
+ `dfc` value corresponds to the contribution of each feature to the overall
+ prediction for this instance (positive indicating that the feature makes
+ it more likely to select class 1 and negative less likely). The 'bias'
+ value will be the same across all the instances, corresponding to the
+ probability (classification) or prediction (regression) of the training
+ data distribution.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ if not self._center_bias:
+ raise ValueError('center_bias must be enabled during estimator '
+ 'instantiation when using '
+ 'experimental_predict_with_explanations.')
+ # pylint: disable=protected-access
+ if not self._is_classification:
+ identity_inverse_link_fn = self._head._inverse_link_fn in (None,
+ tf_identity)
+ # pylint:enable=protected-access
+ if not identity_inverse_link_fn:
+ raise ValueError(
+ 'For now only identity inverse_link_fn in regression_head is '
+ 'supported for experimental_predict_with_explanations.')
+
+ # pylint:disable=unused-argument
+ def new_model_fn(features, labels, mode):
+ return _bt_explanations_fn(features, self._head,
+ self._sorted_feature_columns)
+
+ # pylint:enable=unused-argument
+ est = estimator.Estimator(
+ model_fn=new_model_fn,
+ model_dir=self.model_dir,
+ config=self.config,
+ warm_start_from=self._warm_start_settings)
+ # Make sure bias and dfc will be in prediction dict.
+ user_supplied_predict_keys = predict_keys is not None
+ if user_supplied_predict_keys:
+ predict_keys = set(predict_keys)
+ predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY)
+ predictions = est.predict(
+ input_fn,
+ predict_keys=predict_keys,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ yield_single_examples=True)
+ for pred in predictions:
+ bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction(
+ pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features,
+ self._is_classification)
+ pred['bias'] = bias
+ pred['dfc'] = dfcs
+ # Don't need to expose serialized proto to end user.
+ del pred[boosted_trees_utils._DEBUG_PROTO_KEY]
+ yield pred
+
+
+# pylint: disable=protected-access
@estimator_export('estimator.BoostedTreesClassifier')
-class BoostedTreesClassifier(estimator.Estimator):
+class BoostedTreesClassifier(_BoostedTreesBase):
"""A Classifier for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1082,14 +1402,13 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes = 2
head, closed_form = _create_classification_head_and_closed_form(
n_classes, weight_column, label_vocabulary=label_vocabulary)
-
# HParams for the model.
tree_hparams = _TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
+ return _bt_model_fn(
features,
labels,
mode,
@@ -1101,11 +1420,17 @@ class BoostedTreesClassifier(estimator.Estimator):
closed_form_grad_and_hess_fn=closed_form)
super(BoostedTreesClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=True)
@estimator_export('estimator.BoostedTreesRegressor')
-class BoostedTreesRegressor(estimator.Estimator):
+class BoostedTreesRegressor(_BoostedTreesBase):
"""A Regressor for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1223,9 +1548,17 @@ class BoostedTreesRegressor(estimator.Estimator):
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
- features, labels, mode, head, feature_columns, tree_hparams,
- n_batches_per_layer, config)
+ return _bt_model_fn(features, labels, mode, head, feature_columns,
+ tree_hparams, n_batches_per_layer, config)
super(BoostedTreesRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=False)
+
+
+# pylint: enable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 6e28c72151..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from google.protobuf import text_format
import numpy as np
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator import run_config
@@ -31,10 +35,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
NUM_FEATURES = 3
@@ -564,6 +570,704 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testFeatureImportancesWithTrainedEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ # It will stop after 5 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.833933, 0.606342, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.579010, 0.420990, 0.0], importances)
+
+ def testFeatureImportancesOnEmptyEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+ def before_run(self, run_context):
+ raise StopIteration('to bail out.')
+
+ # The step-0 checkpoint will have only an empty ensemble.
+ est.train(input_fn,
+ steps=100, # must stop at 0 anyway.
+ hooks=[BailOutWithoutTraining()])
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=False)
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=True)
+
+ def _create_fake_checkpoint_with_tree_ensemble_proto(self,
+ est,
+ tree_ensemble_text):
+ with ops.Graph().as_default():
+ with ops.name_scope('boosted_trees') as name:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ tree_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(tree_ensemble_text, tree_ensemble_proto)
+ stamp_token, _ = tree_ensemble.serialize()
+ restore_op = tree_ensemble.deserialize(
+ stamp_token, tree_ensemble_proto.SerializeToString())
+
+ with session.Session() as sess:
+ resources.initialize_resources(resources.shared_resources()).run()
+ restore_op.run()
+ saver = saver_lib.Saver()
+ save_path = os.path.join(est.model_dir, 'model.ckpt')
+ saver.save(sess, save_path)
+
+ def testFeatureImportancesOnNonEmptyEnsemble(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 3.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 7
+ right_id: 8
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 1.0 * [3 + 1, 2, 2] + 1.0 * [1, 1, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithTreeWeights(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=3,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 12.5
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 0.4
+ tree_weights: 0.6
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 0.4 * [12.5, 0, 5] + 0.6 * [0, 5, 0] + 1.0 * [0, 0, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithAllEmptyTree(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Reverse order because feature importances are sorted by np.argsort(f)[::-1]
+ feature_names_expected = ['f_2_bucketized',
+ 'f_1_bucketized',
+ 'f_0_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, 0.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError,
+ 'all empty or contain only a root node'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testNegativeFeatureImportances(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ # In order to generate a negative feature importances,
+ # We assign an invalid value -1 to tree_weights here.
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: -1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Github #21509 (nataliaponomareva):
+ # The gains stored in the splits can be negative
+ # if people are using complexity regularization.
+ feature_names_expected = ['f_2_bucketized',
+ 'f_0_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, -5.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testFeatureImportancesNamesForCategoricalColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+ bucketized_col = feature_column.bucketized_column(
+ feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ bucketized_indicator = feature_column.indicator_column(bucketized_col)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[feature_indicator,
+ bucketized_col,
+ bucketized_indicator],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 4.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['categorical_indicator:ok',
+ 'continuous_bucketized_indicator:(-2.0, 0.5)',
+ 'continuous_bucketized_indicator:(-inf, -2.0)',
+ 'categorical_indicator:bad',
+ # Reverse order because feature importances
+ # are sorted by np.argsort(f)[::-1]
+ 'continuous_bucketized_indicator:(12.0, inf)',
+ 'continuous_bucketized_indicator:(0.5, 12.0)',
+ 'continuous_bucketized',
+ 'categorical_indicator:good']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ # Gain sum for each features:
+ # = 1.0 * [5, 0, 2, 0, 0, 0, 0, 0] + 1.0 * [0, 2, 0, 1, 0, 0, 0, 0]
+ self.assertAllClose([5.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
+
+ def testFeatureImportancesNamesForUnsupportedColumn(self):
+ numeric_col = feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'only bucketized_column and indicator_column'):
+ _ = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[numeric_col],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
+
+class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
+ """Test debug/model explainability outputs for individual predictions.
+
+ Includes directional feature contributions (DFC).
+ """
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testBinaryClassifierThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=3, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([0.4] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: 0.19650601422250574,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: 0.16057487356133376,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: -0.10832468554550384,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == probabilities.
+ expected_probabilities = [
+ 0.23965894, 0.62344426, 0.58751315, 0.23965894, 0.31861359
+ ]
+ probabilities = [
+ sum(dfc.values()) + bias for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_probabilities, probabilities)
+
+ # When user doesn't include bias or dfc in predict_keys, make sure to still
+ # include dfc and bias.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['probabilities'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('probabilities' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+ def testRegressorThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils.py b/tensorflow/python/estimator/canned/boosted_trees_utils.py
new file mode 100644
index 0000000000..85efc2304a
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug and model explainability logic for boosted trees."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+
+# For directional feature contributions.
+_DEBUG_PROTO_KEY = '_serialized_debug_outputs_proto'
+_BIAS_ID = 0
+
+
+def _parse_debug_proto_string(example_proto_serialized):
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example_proto_serialized)
+ feature_ids = example_debug_outputs.feature_ids
+ logits_path = example_debug_outputs.logits_path
+ return feature_ids, logits_path
+
+
+def _compute_directional_feature_contributions(example_feature_ids,
+ example_logits_paths, activation,
+ num_bucketized_features):
+ """Directional feature contributions and bias, per example."""
+ # Initialize contributions to 0.
+ dfcs = {k: 0 for k in range(num_bucketized_features)}
+
+ # Traverse tree subtracting child prediction from parent prediction and
+ # associating change with feature id used to split.
+ predictions = np.array(activation(example_logits_paths))
+ delta_pred = predictions[_BIAS_ID + 1:] - predictions[:-1]
+ # Group by feature id, then sum delta_pred.
+ contribs = np.bincount(
+ example_feature_ids,
+ weights=delta_pred,
+ minlength=num_bucketized_features)
+ for f, dfc in zip(range(num_bucketized_features), contribs):
+ dfcs[f] = dfc
+ return predictions[_BIAS_ID], dfcs
+
+
+def _identity(logits):
+ return logits
+
+
+def _sigmoid(logits):
+ # TODO(crawles): Change to softmax once multiclass support is available.
+ return 1 / (1 + np.exp(-np.array(logits)))
+
+
+def _parse_explanations_from_prediction(serialized_debug_proto,
+ n_features,
+ classification=False):
+ """Parse serialized explanability proto, compute dfc, and return bias, dfc."""
+ feature_ids, logits_path = _parse_debug_proto_string(serialized_debug_proto)
+ if classification:
+ activation = _sigmoid
+ else:
+ activation = _identity
+ bias, dfcs = _compute_directional_feature_contributions(
+ feature_ids, logits_path, activation, n_features)
+ # TODO(crawles): Prediction path and leaf IDs.
+ return bias, dfcs
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils_test.py b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
new file mode 100644
index 0000000000..506d4ea6fb
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import boosted_trees_utils
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BoostedTreesDFCTest(test_util.TensorFlowTestCase):
+ """Test directional feature contributions (DFC) helper functions. """
+
+ def testDirectionalFeatureContributionsCompute(self):
+ """Tests logic to compute DFCs given feature ids and logits paths."""
+ num_bucketized_features = 3 # Includes one unused feature.
+ examples_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+ e1_feature_ids, e2_feature_ids = examples_feature_ids
+
+ # DFCs are computed by traversing the prediction path and subtracting each
+ # child prediction from its parent prediction and associating the change in
+ # prediction with the respective feature id used for the split.
+ # For each activation function, f, (currently identity or sigmoid), DFCs are
+ # calculated for the two examples as:
+ # example 1:
+ # feature_0 = (f(1.114) - f(1.214)) + (f(6.114) - f(1.114))
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.214) - f(0.114))
+ # example 2:
+ # feature_0 = f(-5.486) - f(1.514)
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.514) - f(0.114))
+ # where bias_pred is = f(0) or f(0.21), with center_bias = {True, False},
+ # respectively.
+ # Keys are center_bias.
+ expected_dfcs_identity = {
+ False: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.214
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.514
+ }),
+ True: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.0039999999999998
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.3039999999999998
+ })
+ }
+ expected_dfcs_sigmoid = {
+ False: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2710059376234506
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.319653250251275
+ }),
+ True: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2186980280491253
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.26734534067694971
+ })
+ }
+ # pylint: disable=protected-access
+ for f, expected_dfcs in zip(
+ (boosted_trees_utils._identity, boosted_trees_utils._sigmoid),
+ (expected_dfcs_identity, expected_dfcs_sigmoid)):
+ for center_bias in [False, True]:
+ # If not center_bias, the bias after activation is 0.
+ if center_bias:
+ bias_logit = 0.21 # Root node of tree_0.
+ else:
+ bias_logit = 0 # 0 is default value when there is no original_leaf.
+ f_bias = f(bias_logit)
+
+ # Logits before and after, as is outputed from
+ # boosted_trees_ops.example_debug_outputs
+ examples_logits_paths = ((bias_logit, 0.114, 1.214, 1.114, 6.114),
+ (bias_logit, 0.114, 1.514, -5.486))
+ e1_logits_path, e2_logits_path = examples_logits_paths
+ e1_expected_dfcs, e2_expected_dfcs = expected_dfcs[center_bias]
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint:disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, f, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, f, num_bucketized_features)
+ # pylint:enable=line-too-long
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Check if contributions sum to final prediction.
+ # For each tree, get leaf of last tree.
+ expected_logits = (e1_logits_path[-1], e2_logits_path[-1])
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [f(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ preds = [e1_pred, e2_pred]
+ self.assertAllClose(preds, expected_preds)
+ # pylint: enable=protected-access
+
+ def testDFCComputeComparedToExternalExample(self):
+ """Tests `compute_dfc` compared to external example (regression).
+
+ Example from http://blog.datadive.net/interpreting-random-forests.
+ """
+ # DIS:3, RM: 2, LSTAT:1, NOX:0
+ num_bucketized_features = 4
+ e1_feature_ids = (2, 1, 0)
+ e2_feature_ids = (2, 2, 2)
+ e3_feature_ids = (2, 2, 0)
+
+ bias_logit = 22.60 # Root node of tree_0.
+ activation = boosted_trees_utils._identity
+ f_bias = activation(bias_logit)
+ # Logits before and after, as is outputed from
+ # boosted_trees_ops.example_debug_outputs
+ e1_logits_path = (bias_logit, 19.96, 14.91, 18.11)
+ e2_logits_path = (bias_logit, 37.42, 45.10, 45.90)
+ e3_logits_path = (bias_logit, 37.42, 32.30, 33.58)
+ e1_expected_dfcs = {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}
+ e2_expected_dfcs = {0: 0, 1: 0, 2: 23.3, 3: 0}
+ e3_expected_dfcs = {0: 1.28, 1: 0, 2: 9.7, 3: 0}
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint: disable=protected-access
+ # pylint: disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Example 3.
+ e3_bias, e3_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e3_feature_ids, e3_logits_path, activation, num_bucketized_features)
+ # pylint: enable=line-too-long
+ self.assertAllClose(e3_bias, f_bias)
+ self.assertAllClose(e3_dfc, e3_expected_dfcs)
+ # pylint: enable=protected-access
+ # Check if contributions sum to final prediction.
+ # For each tree, get leaf of last tree.
+ expected_logits = (18.11, 45.90, 33.58)
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [activation(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ e3_pred = e3_bias + sum(e3_dfc.values())
+ preds = [e1_pred, e2_pred, e3_pred]
+ self.assertAllClose(preds, expected_preds)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 0f20acefdf..eec64ad452 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -41,7 +41,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
-from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -329,7 +328,7 @@ class Estimator(object):
run_config.TaskType.PS):
raise ValueError(
'Train has been called wrong configuration. Please use '
- 'tf.estimator.train_and_evaluate which calls propper API according '
+ 'tf.estimator.train_and_evaluate which calls proper API according '
'to given configuration. Current configuration: {}.'.format(
self.config))
@@ -490,6 +489,10 @@ class Estimator(object):
yield_single_examples=True):
"""Yields predictions for given features.
+ Please note that interleaving two predict outputs does not work. See:
+ [issue/20506](
+ https://github.com/tensorflow/tensorflow/issues/20506#issuecomment-422208517)
+
Args:
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception
@@ -611,7 +614,7 @@ class Estimator(object):
# pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
"""Exports inference graph as a `SavedModel` into the given dir.
- Note that `export_to_savedmodel` will be renamed to `export_to_saved_model`
+ Note that `export_to_savedmodel` will be renamed to `export_saved_model`
in TensorFlow 2.0. At that time, `export_to_savedmodel` without the
additional underscore will be available only through tf.compat.v1.
@@ -696,7 +699,7 @@ class Estimator(object):
"""
# pylint: enable=line-too-long
# TODO(b/111442174): `export_to_savedmodel` will be renamed to
- # `export_to_saved_model` in TensorFlow 2.0. This function is a wrapper
+ # `export_saved_model` in TensorFlow 2.0. This function is a wrapper
# while staging the new version; do not add any logic here.
return self.export_savedmodel(
export_dir_base,
@@ -1653,7 +1656,7 @@ def _combine_distributed_scaffold(grouped_scaffold, distribution):
def _unwrap_and_concat(value):
value = nest.flatten(distribution.unwrap(value))
if len(value) != 1:
- return array_ops.concat(value)
+ return array_ops.concat(value, 0)
return value[0]
ready_op = distribution.call_for_each_tower(
@@ -1788,18 +1791,9 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
for name, value in sorted(six.iteritems(eval_dict)):
- if isinstance(value, metrics.Metric):
- metric_result = value.result()
- # We expect only one update op for every metric when there is no
- # distribution strategy.
- metric_update = value.updates if distribution else value.updates[0]
- else:
- metric_result = value[0]
- metric_update = value[1]
-
- value_ops[name] = metric_result
+ value_ops[name] = value[0]
update_ops.append(
- distribution.group(metric_update) if distribution else metric_update)
+ distribution.group(value[1]) if distribution else value[1])
update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 3eed1ab163..ed3219c49b 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -376,7 +376,7 @@ class ExportTest(test_util.TensorFlowTestCase):
" } "
"} ", example)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(
serving_input_receiver.features,
feed_dict={
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 7e5a0c80a7..3758243d7b 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -102,6 +102,49 @@ def gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):
return input_fn
+def get_multi_inputs_multi_outputs_data():
+ (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=3,
+ random_seed=_RANDOM_SEED)
+ (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(16,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+ (m_train, _), (m_test, _) = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(8,),
+ num_classes=2,
+ random_seed=_RANDOM_SEED)
+
+ c_train = keras.utils.to_categorical(c_train)
+ c_test = keras.utils.to_categorical(c_test)
+ d_train = keras.utils.to_categorical(d_train)
+ d_test = keras.utils.to_categorical(d_test)
+
+ train_data = {
+ 'input_a': a_train,
+ 'input_b': b_train,
+ 'input_m': m_train,
+ 'output_c': c_train,
+ 'output_d': d_train
+ }
+ test_data = {
+ 'input_a': a_test,
+ 'input_b': b_test,
+ 'input_m': m_test,
+ 'output_c': c_test,
+ 'output_d': d_test
+ }
+
+ return (train_data, test_data)
+
+
def get_resource_for_simple_model(model_type='sequential',
is_evaluate=False,):
if model_type == 'sequential':
@@ -159,20 +202,21 @@ def randomize_io_type(array, name):
def multi_inputs_multi_outputs_model():
- a = keras.layers.Input(shape=(16,), name='input_a')
- b = keras.layers.Input(shape=(16,), name='input_b')
- m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
+ input_a = keras.layers.Input(shape=(16,), name='input_a')
+ input_b = keras.layers.Input(shape=(16,), name='input_b')
+ input_m = keras.layers.Input(shape=(8,), dtype='string', name='input_m')
dense = keras.layers.Dense(8, name='dense_1')
- a_2 = dense(a)
+ interm_a = dense(input_a)
# Read m
- m_2 = keras.layers.Lambda(gen_parsing_ops.string_to_number)(m)
- s_2 = keras.layers.Lambda(lambda k: k[0] * k[1])([m_2, a_2])
- b_2 = dense(b)
- merged = keras.layers.concatenate([s_2, b_2], name='merge')
- c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
- d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
- model = keras.models.Model(inputs=[a, b, m], outputs=[c, d])
+ interm_m = keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)
+ interm_s = keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])
+ interm_b = dense(input_b)
+ merged = keras.layers.concatenate([interm_s, interm_b], name='merge')
+ output_c = keras.layers.Dense(3, activation='softmax', name='dense_2')(merged)
+ output_d = keras.layers.Dense(2, activation='softmax', name='dense_3')(merged)
+ model = keras.models.Model(
+ inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])
model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
@@ -414,51 +458,85 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
]
self.assertAllEqual(est_pred, keras_pred)
- def test_multi_inputs_multi_outputs(self):
- np.random.seed(_RANDOM_SEED)
- (a_train, c_train), (a_test, c_test) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(16,),
- num_classes=3)
- np.random.seed(_RANDOM_SEED)
- (b_train, d_train), (b_test, d_test) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(16,),
- num_classes=2)
- np.random.seed(_RANDOM_SEED)
- (input_m_train, _), (input_m_test, _) = testing_utils.get_test_data(
- train_samples=_TRAIN_SIZE,
- test_samples=50,
- input_shape=(8,),
- num_classes=2)
-
- c_train = keras.utils.to_categorical(c_train)
- c_test = keras.utils.to_categorical(c_test)
- d_train = keras.utils.to_categorical(d_train)
- d_test = keras.utils.to_categorical(d_test)
+ def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
def train_input_fn():
- input_dict = {'input_a': a_train, 'input_b': b_train,
- 'input_m': input_m_train.astype(np.str)}
- output_dict = {'dense_2': c_train, 'dense_3': d_train}
+ input_dict = {
+ 'input_a': train_data['input_a'],
+ 'input_b': train_data['input_b'],
+ 'input_m': train_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': train_data['output_c'],
+ 'dense_3': train_data['output_d']
+ }
return input_dict, output_dict
def eval_input_fn():
- input_dict = {'input_a': a_test, 'input_b': b_test,
- 'input_m': input_m_test.astype(np.str)}
- output_dict = {'dense_2': c_test, 'dense_3': d_test}
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ output_dict = {
+ 'dense_2': test_data['output_c'],
+ 'dense_3': test_data['output_d']
+ }
return input_dict, output_dict
+ def pred_input_fn():
+ input_dict = {
+ 'input_a': test_data['input_a'],
+ 'input_b': test_data['input_b'],
+ 'input_m': test_data['input_m'].astype(np.str)
+ }
+ return input_dict
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn, pred_input_fn)
+
+ def test_multi_inputs_multi_outputs_with_input_fn_as_list(self):
+ train_data, test_data = get_multi_inputs_multi_outputs_data()
+
+ def train_input_fn():
+ input_list = [
+ train_data['input_a'], train_data['input_b'],
+ train_data['input_m'].astype(np.str)
+ ]
+ output_list = [train_data['output_c'], train_data['output_d']]
+ return input_list, output_list
+
+ def eval_input_fn():
+ input_list = [
+ test_data['input_a'], test_data['input_b'],
+ test_data['input_m'].astype(np.str)
+ ]
+ output_list = [test_data['output_c'], test_data['output_d']]
+ return input_list, output_list
+
+ def pred_input_fn():
+ input_list = [
+ test_data['input_a'], test_data['input_b'],
+ test_data['input_m'].astype(np.str)
+ ]
+ return input_list
+
+ self.do_test_multi_inputs_multi_outputs_with_input_fn(
+ train_input_fn, eval_input_fn, pred_input_fn)
+
+ def do_test_multi_inputs_multi_outputs_with_input_fn(
+ self, train_input_fn, eval_input_fn, pred_input_fn):
with self.cached_session():
model = multi_inputs_multi_outputs_model()
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
- before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ baseline_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)
- after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
- self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+ eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(eval_results['loss'], baseline_eval_results['loss'])
+ est_keras.predict(input_fn=pred_input_fn)
def test_init_from_file(self):
if h5py is None:
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 439cc2e3a4..824789467d 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -308,6 +308,8 @@ class EstimatorSpec(
for key, value in six.iteritems(eval_metric_ops):
if isinstance(value, Metric):
vars_to_add.update(value.variables)
+ # Convert Metric instances to (value_tensor, update_op) tuple.
+ eval_metric_ops[key] = (value.result(), value.updates[0])
# Remove variables that are in the local variables collection already.
vars_to_add = vars_to_add.difference(local_vars)
for v in vars_to_add:
@@ -466,13 +468,13 @@ class _TPUEstimatorSpec(
def _check_is_tensor_or_operation(x, name):
- if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):
+ if not (isinstance(x, ops.Operation) or ops.is_dense_tensor_like(x)):
raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))
def _check_is_tensor(x, tensor_name):
"""Returns `x` if it is a `Tensor`, raises TypeError otherwise."""
- if not isinstance(x, ops.Tensor):
+ if not ops.is_dense_tensor_like(x):
raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))
return x
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index ac53a84eef..5800b693b4 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:numpy_io",
+ "//tensorflow/python/estimator:estimator_py",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 2246d2f3e9..9984379e9d 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -169,7 +169,8 @@ def _internal_input_layer(features,
weight_collections=None,
trainable=True,
cols_to_vars=None,
- scope=None):
+ scope=None,
+ cols_to_output_tensors=None):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -202,14 +203,17 @@ def _internal_input_layer(features,
trainable=trainable)
num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
+ output_tensor = array_ops.reshape(
+ tensor, shape=(batch_size, num_elements))
+ output_tensors.append(output_tensor)
if cols_to_vars is not None:
# Retrieve any variables created (some _DenseColumn's don't create
# variables, in which case an empty list is returned).
cols_to_vars[column] = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES,
scope=variable_scope.get_variable_scope().name)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = output_tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -219,7 +223,8 @@ def input_layer(features,
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ cols_to_output_tensors=None):
"""Returns a dense `Tensor` as input layer based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
@@ -264,6 +269,9 @@ def input_layer(features,
dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
<tf.Variable 'some_variable:1' shape=(5, 10)]}
If a column creates no variables, its value will be an empty list.
+ cols_to_output_tensors: If not `None`, must be a dictionary that will be
+ filled with a mapping from '_FeatureColumn' to the associated
+ output `Tensor`s.
Returns:
A `Tensor` which represents input layer of a model. Its shape
@@ -273,8 +281,13 @@ def input_layer(features,
Raises:
ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
"""
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
+ return _internal_input_layer(
+ features,
+ feature_columns,
+ weight_collections=weight_collections,
+ trainable=trainable,
+ cols_to_vars=cols_to_vars,
+ cols_to_output_tensors=cols_to_output_tensors)
# TODO(akshayka): InputLayer should be a subclass of Layer, and it
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 9b482237ab..abb79efa68 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1637,6 +1637,40 @@ class LinearModelTest(test.TestCase):
self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
+ def test_fills_cols_to_output_tensors(self):
+ # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
+ # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
+ # creates a Variable.
+ apple_numeric_column = fc.numeric_column('apple_numeric_column')
+ banana_dense_feature = fc.numeric_column('banana_dense_feature')
+ banana_dense_feature_bucketized = fc.bucketized_column(
+ banana_dense_feature, boundaries=[0.])
+ cherry_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'cherry_sparse_feature', hash_bucket_size=5)
+ dragonfruit_embedding_column = fc.embedding_column(
+ cherry_sparse_column, dimension=10)
+ with ops.Graph().as_default():
+ features = {
+ 'apple_numeric_column': [[3.], [4.]],
+ 'banana_dense_feature': [[-1.], [4.]],
+ 'cherry_sparse_feature': [['a'], ['x']],
+ }
+ cols_to_output_tensors = {}
+ all_cols = [
+ apple_numeric_column, banana_dense_feature_bucketized,
+ dragonfruit_embedding_column
+ ]
+ input_layer = fc.input_layer(
+ features, all_cols, cols_to_output_tensors=cols_to_output_tensors)
+
+ # We check the mapping by checking that we have the right keys,
+ # and that the values (output_tensors) were indeed the ones used to
+ # form the input layer.
+ self.assertItemsEqual(all_cols, cols_to_output_tensors.keys())
+ input_layer_inputs = [tensor for tensor in input_layer.op.inputs[:-1]]
+ output_tensors = [tensor for tensor in cols_to_output_tensors.values()]
+ self.assertItemsEqual(input_layer_inputs, output_tensors)
+
def test_dense_collection(self):
price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 28c5c82d2c..57f7af7635 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -3433,9 +3433,11 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index a8aef3a009..f287289bd0 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -762,13 +762,12 @@ class _FuncGraph(ops.Graph):
if handle_data:
handle_data = handle_data.SerializeToString()
else:
- handle_data = c_api.GetResourceHandleShapeAndType(
- tensor.graph._c_graph, tensor._as_tf_output())
+ handle_data = c_api.GetHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
if handle_data:
- c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
- ph._as_tf_output(),
- compat.as_bytes(handle_data))
+ c_api.SetHandleShapeAndType(ph.graph._c_graph, ph._as_tf_output(),
+ compat.as_bytes(handle_data))
else:
ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
@@ -1097,6 +1096,21 @@ def _from_library(lib):
return initialized.values()
+def _get_experimental_kwarg_as_attr(attr_name, value):
+ """Creates an AttrValue for a python object."""
+ if isinstance(value, bool):
+ return attr_value_pb2.AttrValue(b=value)
+ elif isinstance(value, int):
+ return attr_value_pb2.AttrValue(i=value)
+ elif isinstance(value, float):
+ return attr_value_pb2.AttrValue(f=value)
+ elif isinstance(value, str):
+ return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
+ else:
+ raise ValueError("Unsupported attribute type for %s with type %s" %
+ (attr_name, type(value)))
+
+
def _parse_kwargs_as_attrs(func_name, **kwargs):
"""Parses **kwargs into a node's attributes."""
attrs = {}
@@ -1123,7 +1137,7 @@ def _parse_kwargs_as_attrs(func_name, **kwargs):
kwargs_keys = list(kwargs.keys())
for key in kwargs_keys:
if key.startswith("experimental_"):
- attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(kwargs[key]))
+ attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs[key])
del kwargs[key]
if kwargs:
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 903768a039..f740e5cfaa 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -1331,12 +1331,33 @@ class FunctionsFromProtos(test.TestCase):
def testExperimentalAttrs(self):
@function.Defun(dtypes.int32, experimental_tag="tag_value")
- def FunctionWithAttr(i):
+ def FunctionWithStrAttr(i):
return array_ops.identity(i)
- self.assertTrue("experimental_tag" in FunctionWithAttr.definition.attr)
- self.assertEqual(FunctionWithAttr.definition.attr["experimental_tag"].s,
+ @function.Defun(dtypes.int32, experimental_tag=123)
+ def FunctionWithIntAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=123.0)
+ def FunctionWithFloatAttr(i):
+ return array_ops.identity(i)
+
+ @function.Defun(dtypes.int32, experimental_tag=True)
+ def FunctionWithBoolAttr(i):
+ return array_ops.identity(i)
+
+ self.assertTrue("experimental_tag" in FunctionWithStrAttr.definition.attr)
+ self.assertEqual(FunctionWithStrAttr.definition.attr["experimental_tag"].s,
b"tag_value")
+ self.assertTrue("experimental_tag" in FunctionWithIntAttr.definition.attr)
+ self.assertEqual(FunctionWithIntAttr.definition.attr["experimental_tag"].i,
+ 123)
+ self.assertTrue("experimental_tag" in FunctionWithFloatAttr.definition.attr)
+ self.assertEqual(
+ FunctionWithFloatAttr.definition.attr["experimental_tag"].f, 123.0)
+ self.assertTrue("experimental_tag" in FunctionWithBoolAttr.definition.attr)
+ self.assertEqual(FunctionWithBoolAttr.definition.attr["experimental_tag"].b,
+ True)
@test_util.with_c_shapes
diff --git a/tensorflow/python/framework/load_library.py b/tensorflow/python/framework/load_library.py
index 535c6017f5..908a5f521e 100644
--- a/tensorflow/python/framework/load_library.py
+++ b/tensorflow/python/framework/load_library.py
@@ -18,14 +18,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import errno
import hashlib
import imp
+import os
+import platform
import sys
import threading # pylint: disable=unused-import
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.lib.core import error_codes_pb2 # pylint: disable=unused-import
from tensorflow.python import pywrap_tensorflow as py_tf
+from tensorflow.python.lib.io import file_io
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -98,3 +102,64 @@ def load_file_system_library(library_filename):
RuntimeError: when unable to load the library.
"""
py_tf.TF_LoadLibrary(library_filename)
+
+
+def _is_shared_object(filename):
+ """Check the file to see if it is a shared object, only using extension."""
+ if platform.system() == 'Linux':
+ if filename.endswith('.so'):
+ return True
+ else:
+ index = filename.rfind('.so.')
+ if index == -1:
+ return False
+ else:
+ # A shared object with the API version in filename
+ return filename[index + 4].isdecimal()
+ elif platform.system() == 'Darwin':
+ return filename.endswith('.dylib')
+ elif platform.system() == 'Windows':
+ return filename.endswith('.dll')
+ else:
+ return False
+
+
+@tf_export('load_library')
+def load_library(library_location):
+ """Loads a TensorFlow plugin.
+
+ "library_location" can be a path to a specific shared object, or a folder.
+ If it is a folder, all sahred objects that are named "libtfkernel*" will be
+ loaded. When the library is loaded, kernels registered in the library via the
+ `REGISTER_*` macros are made available in the TensorFlow process.
+
+ Args:
+ library_location: Path to the plugin or the folder of plugins.
+ Relative or absolute filesystem path to a dynamic library file or folder.
+
+ Returns:
+ None
+
+ Raises:
+ OSError: When the file to be loaded is not found.
+ RuntimeError: when unable to load the library.
+ """
+ if file_io.file_exists(library_location):
+ if file_io.is_directory(library_location):
+ directory_contents = file_io.list_directory(library_location)
+
+ kernel_libraries = [
+ os.path.join(library_location, f) for f in directory_contents
+ if _is_shared_object(f)]
+ else:
+ kernel_libraries = [library_location]
+
+ for lib in kernel_libraries:
+ py_tf.TF_LoadLibrary(lib)
+
+ else:
+ raise OSError(
+ errno.ENOENT,
+ 'The file or folder to load kernel libraries from does not exist.',
+ library_location)
+
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 75678cbc01..8bb177939e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args
@@ -2531,8 +2532,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
+ serialized = c_api.GetHandleShapeAndType(op._graph._c_graph, # pylint: disable=protected-access
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
@@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
after this function runs.
"""
- # pylint: disable=protected-access
- # OrderedDict, constructed on Graph creation, makes a simple reference loop
- # and hides it in an __attribute in some Python versions. We don't need to
- # throw an error if we can't find it, but if we do find it we can break the
- # loop to avoid creating work for the garbage collector.
- graph_operations = graph.get_operations()
- problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
- # pylint: enable=protected-access
- if problematic_cycle:
- try:
- del problematic_cycle[0][:]
- except TypeError:
- # This is probably not one of the problematic Python versions. Continue
- # with the rest of our cleanup.
- pass
+ memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
+
# Now clean up Operation<->Graph reference cycles by clearing all of the
# attributes for the Graph and its ops.
+ graph_operations = graph.get_operations()
for op in graph_operations:
op.__dict__ = {}
graph.__dict__ = {}
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index d59adf3d48..c3a3437743 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -2142,8 +2142,8 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def function_with_variables():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(3)
- return v.assign_add(1)
+ self.v = resource_variable_ops.ResourceVariable(3)
+ return self.v.assign_add(1)
with context.eager_mode():
# Each invocation of function_with_variables recreates a variable.
@@ -2188,13 +2188,13 @@ class InitScopeTest(test_util.TensorFlowTestCase):
def inner_function():
with ops.init_scope():
- v = resource_variable_ops.ResourceVariable(1)
- return v.assign_add(2)
+ self.v = resource_variable_ops.ResourceVariable(1)
+ return self.v.assign_add(2)
def outer_function(inner=None):
with ops.init_scope():
- v0 = resource_variable_ops.ResourceVariable(0)
- return v0.assign_add(1) + inner()
+ self.v0 = resource_variable_ops.ResourceVariable(0)
+ return self.v0.assign_add(1) + inner()
with context.eager_mode():
# Each invocation of outer_function recreates variables.
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index d63abd7f01..cd0b03be43 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,6 +24,7 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
+import os
import math
import random
import re
@@ -69,6 +70,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
@@ -413,15 +415,13 @@ def enable_cond_v2(fn):
The wrapped function
"""
- # pylint: disable=protected-access
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops._ENABLE_COND_V2
- control_flow_ops._ENABLE_COND_V2 = True
+ prev_value = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops._ENABLE_COND_V2 = prev_value
- # pylint: enable=protected-access
+ control_flow_ops.ENABLE_COND_V2 = prev_value
return wrapper
@@ -438,7 +438,7 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
@@ -780,7 +780,7 @@ def run_in_graph_and_eager_modes(func=None,
def run_eagerly(self, **kwargs):
if not use_gpu:
- with ops.device("/cpu:0"):
+ with ops.device("/device:CPU:0"):
f(self, **kwargs)
else:
f(self, **kwargs)
@@ -869,6 +869,19 @@ def device(use_gpu):
yield
+class CapturedWrites(object):
+ """A utility class to load the captured writes made to a stream."""
+
+ def __init__(self, capture_location):
+ self.capture_location = capture_location
+
+ def contents(self):
+ """Get the captured writes as a single string."""
+ with open(self.capture_location) as tmp_file:
+ output_data = "".join(tmp_file.readlines())
+ return output_data
+
+
class ErrorLoggingSession(session.Session):
"""Wrapper around a Session that logs errors in run().
"""
@@ -877,7 +890,11 @@ class ErrorLoggingSession(session.Session):
try:
return super(ErrorLoggingSession, self).run(*args, **kwargs)
except Exception as e: # pylint: disable=broad-except
- logging.error(str(e))
+ # Note: disable the logging for OutOfRangeError, which makes the output
+ # of tf.data tests hard to read, because OutOfRangeError is used as the
+ # signal completion
+ if not isinstance(e, errors.OutOfRangeError):
+ logging.error(str(e))
raise
@@ -935,6 +952,52 @@ class TensorFlowTestCase(googletest.TestCase):
self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
return self._tempdir
+ @contextlib.contextmanager
+ def captureWritesToStream(self, stream):
+ """A context manager that captures the writes to a given stream.
+
+ This context manager captures all writes to a given stream inside of a
+ `CapturedWrites` object. When this context manager is created, it yields
+ the `CapturedWrites` object. The captured contents can be accessed by
+ calling `.contents()` on the `CapturedWrites`.
+
+ For this function to work, the stream must have a file descriptor that
+ can be modified using `os.dup` and `os.dup2`, and the stream must support
+ a `.flush()` method. The default python sys.stdout and sys.stderr are
+ examples of this. Note that this does not work in Colab or Jupyter
+ notebooks, because those use alternate stdout streams.
+
+ Example:
+ ```python
+ class MyOperatorTest(test_util.TensorFlowTestCase):
+ def testMyOperator(self):
+ input = [1.0, 2.0, 3.0, 4.0, 5.0]
+ with self.captureWritesToStream(sys.stdout) as captured:
+ result = MyOperator(input).eval()
+ self.assertStartsWith(captured.contents(), "This was printed.")
+ ```
+
+ Args:
+ stream: The stream whose writes should be captured. This
+ stream must have a file descriptor, support writing via using that
+ file descriptor, and must have a `.flush()` method.
+
+ Yields:
+ A `CapturedWrites` object that contains all writes to the specified stream
+ made during this context.
+ """
+ stream.flush()
+ fd = stream.fileno()
+ tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
+ tmp_file = open(tmp_file_path, "w")
+ orig_fd = os.dup(fd)
+ os.dup2(tmp_file.fileno(), fd)
+ try:
+ yield CapturedWrites(tmp_file_path)
+ finally:
+ tmp_file.close()
+ os.dup2(orig_fd, fd)
+
def _AssertProtoEquals(self, a, b, msg=None):
"""Asserts that a and b are the same proto.
@@ -1338,35 +1401,36 @@ class TensorFlowTestCase(googletest.TestCase):
b.shape)
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
+ msgs = [msg]
if not np.allclose(a, b, rtol=rtol, atol=atol):
- # Prints more details than np.testing.assert_allclose.
+ # Adds more details to np.testing.assert_allclose.
#
# NOTE: numpy.allclose (and numpy.testing.assert_allclose)
# checks whether two arrays are element-wise equal within a
# tolerance. The relative difference (rtol * abs(b)) and the
# absolute difference atol are added together to compare against
# the absolute difference between a and b. Here, we want to
- # print out which elements violate such conditions.
+ # tell user which elements violate such conditions.
cond = np.logical_or(
np.abs(a - b) > atol + rtol * np.abs(b),
np.isnan(a) != np.isnan(b))
if a.ndim:
x = a[np.where(cond)]
y = b[np.where(cond)]
- print("not close where = ", np.where(cond))
+ msgs.append("not close where = {}".format(np.where(cond)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not close lhs = ", x)
- print("not close rhs = ", y)
- print("not close dif = ", np.abs(x - y))
- print("not close tol = ", atol + rtol * np.abs(y))
- print("dtype = %s, shape = %s" % (a.dtype, a.shape))
+ msgs.append("not close lhs = {}".format(x))
+ msgs.append("not close rhs = {}".format(y))
+ msgs.append("not close dif = {}".format(np.abs(x - y)))
+ msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
+ msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
# TODO(xpan): There seems to be a bug:
# tensorflow/compiler/tests:binary_ops_test pass with float32
# nan even though the equal_nan is False by default internally.
np.testing.assert_allclose(
- a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True)
+ a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
def _assertAllCloseRecursive(self,
a,
@@ -1548,19 +1612,20 @@ class TensorFlowTestCase(googletest.TestCase):
np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
]):
same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
+ msgs = [msg]
if not np.all(same):
- # Prints more details than np.testing.assert_array_equal.
+ # Adds more details to np.testing.assert_array_equal.
diff = np.logical_not(same)
if a.ndim:
x = a[np.where(diff)]
y = b[np.where(diff)]
- print("not equal where = ", np.where(diff))
+ msgs.append("not equal where = {}".format(np.where(diff)))
else:
# np.where is broken for scalars
x, y = a, b
- print("not equal lhs = ", x)
- print("not equal rhs = ", y)
- np.testing.assert_array_equal(a, b, err_msg=msg)
+ msgs.append("not equal lhs = {}".format(x))
+ msgs.append("not equal rhs = {}".format(y))
+ np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
def assertAllGreater(self, a, comparison_target):
"""Assert element values are all greater than a target value.
@@ -1663,7 +1728,7 @@ class TensorFlowTestCase(googletest.TestCase):
if any of the elements do not fall in the specified range.
"""
target = self._GetNdArray(target)
- if not (np.issubdtype(target.dtype, np.float) or
+ if not (np.issubdtype(target.dtype, np.floating) or
np.issubdtype(target.dtype, np.integer)):
raise AssertionError(
"The value of %s does not have an ordered numeric type, instead it "
@@ -1840,7 +1905,7 @@ class TensorFlowTestCase(googletest.TestCase):
elif use_gpu:
yield sess
else:
- with sess.graph.device("/cpu:0"):
+ with sess.graph.device("/device:CPU:0"):
yield sess
def _create_session(self, graph, config, force_gpu):
@@ -1855,12 +1920,18 @@ class TensorFlowTestCase(googletest.TestCase):
Returns:
A config_pb2.ConfigProto object.
"""
+ # TODO(b/114333779): Enforce allow_soft_placement=False when
+ # use_gpu=False. Currently many tests rely on the fact that any device
+ # will be used even when a specific device is supposed to be used.
+ allow_soft_placement = not force_gpu
if config is None:
config = config_pb2.ConfigProto()
- config.allow_soft_placement = not force_gpu
+ config.allow_soft_placement = allow_soft_placement
config.gpu_options.per_process_gpu_memory_fraction = 0.3
- elif force_gpu and config.allow_soft_placement:
- config = config_pb2.ConfigProto().CopyFrom(config)
+ elif not allow_soft_placement and config.allow_soft_placement:
+ config_copy = config_pb2.ConfigProto()
+ config_copy.CopyFrom(config)
+ config = config_copy
config.allow_soft_placement = False
# Don't perform optimizations for tests so we don't inadvertently run
# gpu ops on cpu
@@ -1869,6 +1940,8 @@ class TensorFlowTestCase(googletest.TestCase):
rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.arithmetic_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
+ config.graph_options.rewrite_options.pin_to_host_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
return config
return ErrorLoggingSession(graph=graph, config=prepare_config(config))
@@ -2010,3 +2083,42 @@ def set_producer_version(graph, producer_version):
with graph.as_default():
importer.import_graph_def(graph_def)
assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+ """Removes reference cycles in `func_graph` FuncGraph.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ the FuncGraph goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+ # Clearing captures using clear() leaves some cycles around.
+ while func_graph.captures:
+ func_graph.captures.popitem()
+ memory.dismantle_ordered_dict(func_graph.captures)
+ ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+ """Removes reference cycles in PolymorphicFunction `func`.
+
+ Helpful for making sure the garbage collector doesn't need to run when
+ PolymorphicFunction goes out of scope, e.g. in tests using defun with
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+ Args:
+ func: A `PolymorphicFunction` object to destroy. `func` is unusable
+ after this function.
+ """
+ # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+ cache = func._function_cache # pylint: disable=protected-access
+ for concrete_func in cache.values():
+ dismantle_func_graph(concrete_func.graph)
+ while cache:
+ cache.popitem()
+ memory.dismantle_ordered_dict(cache)
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index c4f8fa9108..22189afa59 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -268,6 +268,11 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(7, 7 + 1e-5)
@test_util.run_in_graph_and_eager_modes
+ def testAllCloseList(self):
+ with self.assertRaisesRegexp(AssertionError, r"not close dif"):
+ self.assertAllClose([0], [1])
+
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
@@ -452,6 +457,9 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)
+ with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
+ self.assertAllEqual([0] * 3, k)
+
@test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 290e182a79..ac011a2940 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -7,6 +7,7 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
+load("@pip_deps//:requirements.bzl", "requirement")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
@@ -62,6 +63,7 @@ py_library(
":backend",
":engine",
":layers",
+ requirement("keras_applications"),
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -337,11 +339,6 @@ py_test(
size = "large",
srcs = ["layers/convolutional_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "manual",
- "noasan", # times out b/63678675
- "notsan",
- ],
deps = [
":keras",
"//tensorflow/python:client_testlib",
@@ -386,12 +383,11 @@ py_test(
],
)
-py_test(
+cuda_py_test(
name = "embeddings_test",
size = "medium",
srcs = ["layers/embeddings_test.py"],
- srcs_version = "PY2AND3",
- deps = [
+ additional_deps = [
":keras",
"//tensorflow/python:client_testlib",
],
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index a8b6d55e41..c35cdb15a4 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -63,7 +63,8 @@ def keras_modules_injection(base_fun):
def wrapper(*args, **kwargs):
if hasattr(keras_applications, 'get_submodules_from_kwargs'):
kwargs['backend'] = backend
- kwargs['layers'] = layers
+ if 'layers' not in kwargs:
+ kwargs['layers'] = layers
kwargs['models'] = models
kwargs['utils'] = utils
return base_fun(*args, **kwargs)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf0..4589c821e5 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+ pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
# This boolean flag can be set to True to leave variable initialization
# up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
@tf_export('keras.backend.backend')
@@ -358,18 +367,26 @@ def learning_phase():
Returns:
Learning phase (scalar integer tensor or Python integer).
"""
- if context.executing_eagerly():
- if 'eager' not in _GRAPH_LEARNING_PHASES:
- # Fallback to inference mode as default.
- return 0
- return _GRAPH_LEARNING_PHASES['eager']
+ with ops.init_scope():
+ # We always check & set the learning phase inside the init_scope,
+ # otherwise the wrong default_graph will be used to look up the learning
+ # phase inside of functions & defuns.
+ #
+ # This is because functions & defuns (both in graph & in eager mode)
+ # will always execute non-eagerly using a function-specific default
+ # subgraph.
+ if context.executing_eagerly():
+ if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
+ # Fallback to inference mode as default.
+ return 0
+ return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
- graph = ops.get_default_graph()
- if graph not in _GRAPH_LEARNING_PHASES:
- phase = array_ops.placeholder_with_default(
- False, shape=(), name='keras_learning_phase')
- _GRAPH_LEARNING_PHASES[graph] = phase
- return _GRAPH_LEARNING_PHASES[graph]
+ graph = ops.get_default_graph()
+ if graph not in _GRAPH_LEARNING_PHASES:
+ phase = array_ops.placeholder_with_default(
+ False, shape=(), name='keras_learning_phase')
+ _GRAPH_LEARNING_PHASES[graph] = phase
+ return _GRAPH_LEARNING_PHASES[graph]
@tf_export('keras.backend.set_learning_phase')
@@ -385,10 +402,11 @@ def set_learning_phase(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.')
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@tf_contextlib.contextmanager
@@ -414,10 +432,11 @@ def learning_phase_scope(value):
yield value
finally:
# Restore learning phase to initial value.
- if context.executing_eagerly():
- _GRAPH_LEARNING_PHASES['eager'] = previous_value
- else:
- _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
+ with ops.init_scope():
+ if context.executing_eagerly():
+ _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+ else:
+ _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
@tf_export('keras.backend.get_session')
@@ -676,10 +695,8 @@ def track_tf_optimizer(tf_optimizer):
if context.executing_eagerly():
return
graph = ops.get_default_graph()
- if graph not in _GRAPH_TF_OPTIMIZERS:
- _GRAPH_TF_OPTIMIZERS[graph] = set()
- _GRAPH_TF_OPTIMIZERS[graph].add(tf_optimizer)
-
+ optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+ optimizers.add(tf_optimizer)
def track_variable(v):
"""Tracks the given variable for initialization."""
@@ -687,14 +704,14 @@ def track_variable(v):
return
graph = v.graph if hasattr(v, 'graph') else ops.get_default_graph()
if graph not in _GRAPH_VARIABLES:
- _GRAPH_VARIABLES[graph] = set()
+ _GRAPH_VARIABLES[graph] = weakref.WeakSet()
_GRAPH_VARIABLES[graph].add(v)
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
- variables = _GRAPH_VARIABLES.get(graph, set())
+ variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables.update(opt.optimizer.variables())
return variables
@@ -3450,14 +3467,18 @@ def relu(x, alpha=0., max_value=None, threshold=0):
Returns:
A tensor.
"""
- clip_max = max_value is not None
if alpha != 0.:
+ if max_value is None and threshold == 0:
+ return nn.leaky_relu(x, alpha=alpha)
+
if threshold != 0:
negative_part = nn.relu(-x + threshold)
else:
negative_part = nn.relu(-x)
+ clip_max = max_value is not None
+
if threshold != 0:
# computes x for x > threshold else 0
x = x * math_ops.cast(math_ops.greater(x, threshold), floatx())
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 266af56611..ab71589940 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -279,7 +279,7 @@ class BackendUtilsTest(test.TestCase):
keras.backend.get_session().run(fetches=[x, y]), [30., 40.])
def test_function_tf_run_options_with_run_metadata(self):
- with self.test_session():
+ with self.cached_session():
x_placeholder = keras.backend.placeholder(shape=())
y_placeholder = keras.backend.placeholder(shape=())
@@ -522,8 +522,9 @@ class BackendLinearAlgebraTest(test.TestCase):
relu_op = keras.backend.relu(x)
self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
- # alpha
+ # alpha (leaky relu used)
relu_op = keras.backend.relu(x, alpha=0.5)
+ self.assertTrue('LeakyRelu' in relu_op.name)
self.assertAllClose(keras.backend.eval(relu_op), [[-2, 0], [2, 7]])
# max_value < some elements
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index befe82f4ec..6dfbbf3694 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -360,7 +360,10 @@ class BaseLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
- self.seen += batch_size
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
+ self.seen += batch_size * num_steps
for k, v in logs.items():
if k in self.stateful_metrics:
@@ -448,10 +451,13 @@ class ProgbarLogger(Callback):
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
+ # In case of distribution strategy we can potentially run multiple steps
+ # at the same time, we should account for that in the `seen` calculation.
+ num_steps = logs.get('num_steps', 1)
if self.use_steps:
- self.seen += 1
+ self.seen += num_steps
else:
- self.seen += batch_size
+ self.seen += batch_size * num_steps
for k in self.params['metrics']:
if k in logs:
@@ -1068,7 +1074,7 @@ class TensorBoard(Callback):
logs = logs or {}
batch_logs = {('batch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._total_batches_seen += 1
@@ -1092,7 +1098,7 @@ class TensorBoard(Callback):
# batch number as Tensorboard summaries
logs = {('epoch_' + k): v
for k, v in logs.items()
- if k not in ['batch', 'size']}
+ if k not in ['batch', 'size', 'num_steps']}
self._write_custom_summaries(epoch, logs)
# pop the histogram summary op after each epoch
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 7675a6586f..b6fae19823 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -63,7 +63,7 @@ class KerasCallbacksTest(test.TestCase):
if h5py is None:
return # Skip test if models cannot be saved.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
temp_dir = self.get_temp_dir()
@@ -226,7 +226,7 @@ class KerasCallbacksTest(test.TestCase):
mode='unknown')
def test_EarlyStopping(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -265,7 +265,7 @@ class KerasCallbacksTest(test.TestCase):
verbose=0)
def test_EarlyStopping_reuse(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
patience = 3
data = np.random.random((100, 1))
@@ -287,7 +287,7 @@ class KerasCallbacksTest(test.TestCase):
assert len(hist.epoch) >= patience
def test_EarlyStopping_with_baseline(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
baseline = 0.5
(data, labels), _ = testing_utils.get_test_data(
@@ -321,7 +321,7 @@ class KerasCallbacksTest(test.TestCase):
monitor.on_epoch_end(0, logs={'loss': 0.})
def test_LearningRateScheduler(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -368,7 +368,7 @@ class KerasCallbacksTest(test.TestCase):
model.optimizer.lr)) - 0.01 / 4) < keras.backend.epsilon()
def test_ReduceLROnPlateau(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -470,7 +470,7 @@ class KerasCallbacksTest(test.TestCase):
self.assertEqual(reduce_on_plateau.min_delta, 1e-13)
def test_CSVLogger(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
@@ -549,7 +549,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
fp = os.path.join(tmpdir, 'test.csv')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -601,7 +601,7 @@ class KerasCallbacksTest(test.TestCase):
assert 'nan' in values[-1], 'The last epoch was not logged.'
def test_TerminateOnNaN(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -666,7 +666,7 @@ class KerasCallbacksTest(test.TestCase):
i %= max_batch_index
# case: Sequential
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Dense(
@@ -743,7 +743,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
filepath = os.path.join(tmpdir, 'logs')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -815,7 +815,7 @@ class KerasCallbacksTest(test.TestCase):
tmpdir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
- with self.test_session():
+ with self.cached_session():
filepath = os.path.join(tmpdir, 'logs')
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
@@ -925,7 +925,7 @@ class KerasCallbacksTest(test.TestCase):
y_test = keras.utils.to_categorical(y_test)
y_train = keras.utils.to_categorical(y_train)
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.Dense(
@@ -969,7 +969,7 @@ class KerasCallbacksTest(test.TestCase):
while True:
yield x, y
- with self.test_session():
+ with self.cached_session():
model = testing_utils.get_small_sequential_mlp(
num_hidden=10, num_classes=10, input_dim=100)
model.compile(
@@ -1011,7 +1011,7 @@ class KerasCallbacksTest(test.TestCase):
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
def test_LambdaCallback(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
@@ -1055,7 +1055,7 @@ class KerasCallbacksTest(test.TestCase):
assert not t.is_alive()
def test_TensorBoard_with_ReduceLROnPlateau(self):
- with self.test_session():
+ with self.cached_session():
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
@@ -1194,7 +1194,7 @@ class KerasCallbacksTest(test.TestCase):
def test_RemoteMonitorWithJsonPayload(self):
if requests is None:
self.skipTest('`requests` required to run this test')
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index b28df75493..39341a931b 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_module
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
@@ -293,12 +294,14 @@ def configure_and_create_session(distribution_strategy):
K.set_session(session)
-def validate_inputs(x, y):
+def validate_inputs(x, y, distribution_strategy):
"""Validate inputs when using DistributionStrategy.
Args:
x: Model Inputs.
y: Model Targets.
+ distribution_strategy: The DistributionStrategy with which the model is
+ compiled.
Raises:
ValueError: if input is not a Dataset or a numpy array.
@@ -319,6 +322,17 @@ def validate_inputs(x, y):
'Iterator. You must pass a Dataset object or a numpy '
'array as input.')
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ for i in [x, y]:
+ if isinstance(i, dataset_ops.Dataset):
+ shapes = nest.flatten(i.output_shapes)
+ if any([not s.is_fully_defined() for s in shapes]):
+ raise ValueError(
+ 'Using TPUs currently requires fully defined shapes. Either use '
+ 'set_shape() on the input tensors or use '
+ 'dataset.batch(..., drop_remainder=True).'
+ 'Found unknown shape {} in input {}.'.format(s, i))
+
def get_input_batch_params(first_x_value, batch_size, current_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 148dd23be7..02d99d5d69 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -370,6 +370,13 @@ class TestWholeModelSaving(test.TestCase):
y = np.random.random((1, 3, 3))
model.train_on_batch(x, y)
new_model.train_on_batch(x, y)
+
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ eval_out = model.evaluate(x, y)
+ eval_out2 = new_model.evaluate(x, y)
+ self.assertArrayNear(eval_out, eval_out2, 0.001)
+
out = model.predict(x)
out2 = new_model.predict(x)
self.assertAllClose(out, out2, atol=1e-05)
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index 061db8ee34..a0da96334b 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -915,7 +915,7 @@ class TopologyConstructionTest(test.TestCase):
def test_constant_initializer_with_numpy(self):
- with self.test_session():
+ with self.cached_session():
initializer = keras.initializers.Constant(np.ones((3, 2)))
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,),
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 49b25e307e..ade8a4b32d 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -41,6 +41,7 @@ from tensorflow.python.keras.engine import training_eager
from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine.network import Network
+from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
@@ -144,32 +145,34 @@ class Model(Network):
if i not in skip_target_weighing_indices
]
- def _get_metric_name(self, metric, output_index, weighted=False):
- """Returns the metric name corresponding to the given metric input.
+ def _cache_output_metric_attributes(self, metrics, weighted_metrics):
+ """Caches metric name and function attributes for every model output."""
+ output_shapes = [
+ None if output is None else output.get_shape().as_list()
+ for output in self.outputs
+ ]
+ self._per_output_metrics = training_utils.collect_per_output_metric_info(
+ metrics, self.output_names, output_shapes, self.loss_functions)
+ self._per_output_weighted_metrics = \
+ training_utils.collect_per_output_metric_info(
+ weighted_metrics, self.output_names, output_shapes,
+ self.loss_functions, self.sample_weights)
+
+ def _add_unique_metric_name(self, metric_name, output_index):
+ """Makes the metric name unique and adds it to the model's metric name list.
+
+ If there are multiple outputs for which the metrics are calculated, the
+ metric names have to be made unique by appending an integer.
Arguments:
- metric: Metric function name or reference.
- output_index: Index of the current output.
- weighted: Boolean indicating if the given metric is weighted.
+ metric_name: Metric name that corresponds to the metric specified by the
+ user. For example: 'acc'.
+ output_index: The index of the model output for which the metric name is
+ being added.
Returns:
- A metric name.
+ string, name of the model's unique metric name
"""
- metric_name_prefix = 'weighted_' if weighted else ''
- if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
- if metric in ('accuracy', 'acc'):
- suffix = 'acc'
- elif metric in ('crossentropy', 'ce'):
- suffix = 'ce'
- else:
- metric_fn = metrics_module.get(metric)
- # Get metric name as string
- if hasattr(metric_fn, 'name'):
- suffix = metric_fn.name
- else:
- suffix = metric_fn.__name__
- metric_name = metric_name_prefix + suffix
-
if len(self.output_names) > 1:
metric_name = '%s_%s' % (self.output_names[output_index], metric_name)
j = 1
@@ -180,24 +183,54 @@ class Model(Network):
return metric_name
+ def _init_metric_attributes(self):
+ """Initialized model metric attributes."""
+ self.metrics_names = ['loss']
+ self.metrics_tensors = []
+ self.metrics_updates = []
+ self.stateful_metric_names = []
+ self.stateful_metric_functions = []
+
+ def _set_per_output_metric_attributes(self, metrics_dict, output_index):
+ """Sets the metric attributes on the model for the given output.
+
+ Arguments:
+ metrics_dict: A dict with metric names as keys and metric fns as values.
+ output_index: The index of the model output for which the metric
+ attributes are added.
+ """
+ for metric_name, metric_fn in metrics_dict.items():
+ metric_name = self._add_unique_metric_name(metric_name, output_index)
+ # Keep track of metric name.
+ self.metrics_names.append(metric_name)
+
+ # Keep track of stateful metric attributes (name and metric function).
+ if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
+ self.stateful_metric_names.append(metric_name)
+ self.stateful_metric_functions.append(metric_fn)
+
+ def _set_metric_attributes(self, outputs, skip_target_indices=None):
+ """Sets the metric attributes on the model for all the model outputs."""
+ skip_target_indices = skip_target_indices or []
+ for i in range(len(outputs)):
+ if i in skip_target_indices:
+ continue
+ self._set_per_output_metric_attributes(self._per_output_metrics[i], i)
+ self._set_per_output_metric_attributes(
+ self._per_output_weighted_metrics[i], i)
+
def _handle_per_output_metrics(self,
- metrics,
+ metrics_dict,
y_true,
y_pred,
- output_index,
- output_shape,
- loss_fn,
mask,
weights=None):
- """Calls metric functions and sets metric attributes for a single output.
+ """Calls metric functions for a single output.
Arguments:
- metrics: List of metrics.
+ metrics_dict: A dict with metric names as keys and metric fns as values.
y_true: Target output.
y_pred: Predicted output.
- output_index: Index of the current output.
- output_shape: Shape of the current output.
- loss_fn: Loss function corresponding to the current output.
mask: Computed mask value for the current output.
weights: Weights to be applied on the current output.
@@ -205,60 +238,47 @@ class Model(Network):
A list of metric result tensors.
"""
metric_results = []
- for metric in metrics:
- metric_fn = training_utils.get_metric_function(
- metric, output_shape=output_shape, loss_fn=loss_fn)
- metric_name = self._get_metric_name(
- metric, output_index, weighted=weights is not None)
-
+ for metric_name, metric_fn in metrics_dict.items():
with K.name_scope(metric_name):
- # If both outputs and targets are available, call the metric function.
- if y_true is not None and y_pred is not None:
- if isinstance(metric_fn, metrics_module.Metric):
- # Call the stateful metric function.
- if mask is not None:
- mask = math_ops.cast(mask, y_pred.dtype)
- # Update weights with mask.
- if weights is None:
- weights = mask
- else:
- # Update shape of weights if possible before adding mask.
- # Update dimensions of weights to match with mask if possible.
- mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
- mask, None, weights)
- try:
- # Broadcast weights if possible.
- weights = weights_broadcast_ops.broadcast_weights(
- weights, mask)
- except ValueError:
- pass
- # TODO(psv): Handle case when mask and weight shapes are not
- # compatible.
- weights *= mask
-
- metric_result = metric_fn(y_true, y_pred, weights)
- else:
- # Call the stateless metric function.
- weighted_metric_fn = training_utils.weighted_masked_objective(
- metric_fn)
- metric_result = weighted_metric_fn(
- y_true, y_pred, weights=weights, mask=mask)
-
- if not context.executing_eagerly():
- # Keep track of metric result tensor.
- self.metrics_tensors.append(metric_result)
- metric_results.append(metric_result)
-
- # Keep track of metric name.
- self.metrics_names.append(metric_name)
+ if isinstance(metric_fn, metrics_module.Metric):
+ # Call the stateful metric function.
+ if mask is not None:
+ mask = math_ops.cast(mask, y_pred.dtype)
+ # Update weights with mask.
+ if weights is None:
+ weights = mask
+ else:
+ # Update shape of weights if possible before adding mask.
+ # Update dimensions of weights to match with mask if possible.
+ mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
+ mask, None, weights)
+ try:
+ # Broadcast weights if possible.
+ weights = weights_broadcast_ops.broadcast_weights(weights, mask)
+ except ValueError:
+ pass
+ # TODO(psv): Handle case when mask and weight shapes are not
+ # compatible.
+ weights *= mask
+
+ metric_result = metric_fn(y_true, y_pred, weights)
+ else:
+ # Call the stateless metric function.
+ weighted_metric_fn = training_utils.weighted_masked_objective(
+ metric_fn)
+ metric_result = weighted_metric_fn(
+ y_true, y_pred, weights=weights, mask=mask)
- # Keep track of stateful metric attributes (name and metric function).
- if isinstance(metric_fn, base_layer.Layer) and metric_fn.stateful:
- self.stateful_metric_names.append(metric_name)
- self.stateful_metric_functions.append(metric_fn)
if not context.executing_eagerly():
- # Keep track of updates created by stateful metrics.
- self.metrics_updates += metric_fn.updates
+ # Keep track of metric result tensor.
+ self.metrics_tensors.append(metric_result)
+
+ metric_results.append(metric_result)
+ is_stateful = isinstance(metric_fn,
+ base_layer.Layer) and metric_fn.stateful
+ if is_stateful and not context.executing_eagerly():
+ # Keep track of updates created by stateful metrics.
+ self.metrics_updates += metric_fn.updates
return metric_results
def _handle_metrics(self,
@@ -267,7 +287,7 @@ class Model(Network):
targets=None,
sample_weights=None,
masks=None):
- """Handles calling metric functions and setting model metric attributes.
+ """Handles calling metric functions.
Arguments:
outputs: List of outputs (predictions).
@@ -287,20 +307,15 @@ class Model(Network):
continue
output = outputs[i] if outputs else None
target = targets[i] if targets else None
- output_shape = None if output is None else output.get_shape().as_list()
output_mask = masks[i] if masks else None
metric_results.extend(
- self._handle_per_output_metrics(
- self.nested_metrics[i], target, output, i, output_shape,
- self.loss_functions[i], output_mask))
+ self._handle_per_output_metrics(self._per_output_metrics[i], target,
+ output, output_mask))
metric_results.extend(
self._handle_per_output_metrics(
- self.nested_weighted_metrics[i],
+ self._per_output_weighted_metrics[i],
target,
output,
- i,
- output_shape,
- self.loss_functions[i],
output_mask,
weights=sample_weights[i]))
return metric_results
@@ -368,27 +383,31 @@ class Model(Network):
"""
# Validate that arguments passed by the user to `compile` are supported by
# DistributionStrategy.
- if distribute and not isinstance(
- optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise NotImplementedError('Only TF native optimizers are supported with '
- 'DistributionStrategy.')
- if distribute and context.executing_eagerly():
- raise NotImplementedError('DistributionStrategy is not supported in '
- 'Eager mode.')
- if distribute and sample_weight_mode:
- raise NotImplementedError('sample_weight_mode is not supported with '
- 'DistributionStrategy.')
- if distribute and weighted_metrics:
- raise NotImplementedError('weighted_metrics is not supported with '
- 'DistributionStrategy.')
- if distribute and target_tensors:
- raise ValueError('target_tensors is not supported with '
- 'DistributionStrategy.')
+ if distribute:
+ if not isinstance(
+ optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
+ raise NotImplementedError(
+ 'optimizer must be an instance of '
+ 'tf.train.Optimizer, not a %s' % type(optimizer))
+ if context.executing_eagerly():
+ raise NotImplementedError('DistributionStrategy is not supported '
+ 'when eager execution is enabled.')
+ if sample_weight_mode:
+ raise NotImplementedError('sample_weight_mode is not supported with '
+ 'DistributionStrategy.')
+ if weighted_metrics:
+ raise NotImplementedError('weighted_metrics is not supported with '
+ 'DistributionStrategy.')
+ if target_tensors:
+ raise ValueError('target_tensors is not supported with '
+ 'DistributionStrategy.')
loss = loss or {}
if context.executing_eagerly() and not isinstance(
optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
- raise ValueError('Only TF native optimizers are supported in Eager mode.')
+ raise ValueError(
+ 'optimizer must be an instance of tf.train.Optimizer, not '
+ 'a %s' % type(optimizer))
self.optimizer = optimizers.get(optimizer)
# We've disabled automatic dependency tracking for this method, but do want
@@ -407,8 +426,9 @@ class Model(Network):
# Set DistributionStrategy specific parameters.
self._distribution_strategy = distribute
+ # Reset the value of grouped_model
+ self._grouped_model = None
if self._distribution_strategy is not None:
- self._grouped_model = None
distributed_training_utils.configure_and_create_session(
self._distribution_strategy)
if not self.built:
@@ -430,7 +450,8 @@ class Model(Network):
for name in self.output_names:
if name not in loss:
logging.warning(
- 'Output "' + name + '" missing from loss dictionary. We assume '
+ 'Output "' + name +
+ '" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
loss_functions.append(losses.get(loss.get(name)))
@@ -492,24 +513,15 @@ class Model(Network):
self.loss_weights_list = loss_weights_list
# Initialize model metric attributes.
- self.metrics_names = ['loss']
- self.metrics_tensors = []
- self.metrics_updates = []
- self.stateful_metric_names = []
- self.stateful_metric_functions = []
-
- # Nested metrics is a list of list of metrics.
- # One list per output of the model.
- self.nested_metrics = training_utils.collect_metrics(
- metrics, self.output_names)
- self.nested_weighted_metrics = training_utils.collect_metrics(
- weighted_metrics, self.output_names)
+ self._init_metric_attributes()
# Initialization for Eager mode execution.
if context.executing_eagerly():
# Prepare sample weights.
self._set_sample_weight_attributes(sample_weight_mode,
skip_target_weighing_indices)
+ # Save all metric attributes per output of the model.
+ self._cache_output_metric_attributes(metrics, weighted_metrics)
if target_tensors is not None:
raise ValueError('target_tensors are not currently supported in Eager '
@@ -520,10 +532,10 @@ class Model(Network):
self.metrics_names.append(self.output_names[i] + '_loss')
# Set metric attributes on model.
- self._handle_metrics(
+ self._set_metric_attributes(
self.outputs,
skip_target_indices=skip_target_indices,
- sample_weights=self.sample_weights)
+ )
self.targets = []
for i in range(len(self.outputs)):
@@ -586,6 +598,8 @@ class Model(Network):
# Prepare sample weights.
self._set_sample_weight_attributes(sample_weight_mode,
skip_target_weighing_indices)
+ # Save all metric attributes per output of the model.
+ self._cache_output_metric_attributes(metrics, weighted_metrics)
# Compute total loss.
total_loss = None
@@ -620,6 +634,11 @@ class Model(Network):
for loss_tensor in self.losses:
total_loss += loss_tensor
+ # Set metric attributes on model.
+ self._set_metric_attributes(
+ self.outputs,
+ skip_target_indices=skip_target_indices,
+ )
# Invoke metric functions for all the outputs.
self._handle_metrics(
self.outputs,
@@ -1338,6 +1357,9 @@ class Model(Network):
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False,
**kwargs):
"""Trains the model for a fixed number of epochs (iterations on a dataset).
@@ -1350,19 +1372,23 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator. Should return a tuple
- of either (inputs, targets) or (inputs, targets, sample_weights).
+ of either `(inputs, targets)` or
+ `(inputs, targets, sample_weights)`.
+ - A generator or `keras.utils.Sequence` returning `(inputs, targets)`
+ or `(inputs, targets, sample weights)`.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
- tensor targets, or inversely). If `x` is a dataset or dataset
- iterator, `y` should not be specified
- (since targets will be obtained from the iterator).
+ tensor targets, or inversely). If `x` is a dataset, dataset
+ iterator, generator, or `keras.utils.Sequence` instance, `y` should
+ not be specified (since targets will be obtained from `x`).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
epochs: Integer. Number of epochs to train the model.
An epoch is an iteration over the entire `x` and `y`
data provided.
@@ -1384,7 +1410,8 @@ class Model(Network):
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
- not supported when `x` is a dataset or a dataset iterator.
+ not supported when `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
The model will not be trained on this data.
@@ -1415,8 +1442,9 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator, instead
- provide the sample_weights as the third element of `x`.
+ supported when `x` is a dataset, dataset iterator, generator, or
+ `keras.utils.Sequence` instance, instead provide the sample_weights
+ as the third element of `x`.
initial_epoch: Integer.
Epoch at which to start training
(useful for resuming a previous training run).
@@ -1430,6 +1458,20 @@ class Model(Network):
validation_steps: Only relevant if `steps_per_epoch`
is specified. Total number of steps (batches of samples)
to validate before stopping.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up
+ when using process-based threading. If unspecified, `workers`
+ will default to 1. If 0, will execute the generator on the main
+ thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
**kwargs: Used for backwards compatibility.
Returns:
@@ -1446,6 +1488,23 @@ class Model(Network):
# TODO(fchollet): this method may be creating reference cycles, which would
# lead to accumulating garbage in memory when called in a loop. Investigate.
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.fit_generator(
+ x,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_data=validation_data,
+ validation_steps=validation_steps,
+ class_weight=class_weight,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch)
+
# Backwards compatibility
if batch_size is None and steps_per_epoch is None:
batch_size = 32
@@ -1462,7 +1521,8 @@ class Model(Network):
if self._distribution_strategy:
distributed_training_utils.validate_callbacks(callbacks)
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if not steps_per_epoch and isinstance(first_x_value, np.ndarray):
@@ -1504,7 +1564,8 @@ class Model(Network):
# Validate and standardize validation data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(val_x, val_y)
+ distributed_training_utils.validate_inputs(
+ val_x, val_y, self._distribution_strategy)
first_valx_value = nest.flatten(val_x)[0]
if not validation_steps and isinstance(first_valx_value, np.ndarray):
validation_steps = distributed_training_utils.get_input_batch_params(
@@ -1588,7 +1649,10 @@ class Model(Network):
batch_size=None,
verbose=1,
sample_weight=None,
- steps=None):
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Returns the loss value & metrics values for the model in test mode.
Computation is done in batches.
@@ -1602,18 +1666,21 @@ class Model(Network):
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely).
- If `x` is a dataset or a dataset iterator, `y` should not be specified
- (since targets will be obtained from the iterator/dataset).
+ If `x` is a dataset, dataset iterator, generator or
+ `keras.utils.Sequence` instance, `y` should not be specified (since
+ targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` is your data is in the
- form of symbolic tensors, datasets, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: 0 or 1. Verbosity mode.
0 = silent, 1 = progress bar.
sample_weight: Optional Numpy array of weights for
@@ -1627,11 +1694,25 @@ class Model(Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
- supported when `x` is a dataset or a dataset iterator.
+ supported when `x` is a dataset or a dataset iterator, instead pass
+ sample weights as the third element of `x`.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
Returns:
Scalar test loss (if the model has a single output and no metrics)
@@ -1642,13 +1723,24 @@ class Model(Network):
Raises:
ValueError: in case of invalid arguments.
"""
+ if data_utils.is_generator_or_sequence(x):
+ training_utils.check_generator_arguments(y, sample_weight)
+ return self.evaluate_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
# Validate and standardize user data.
if self._distribution_strategy:
- distributed_training_utils.validate_inputs(x, y)
+ distributed_training_utils.validate_inputs(
+ x, y, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
@@ -1688,7 +1780,14 @@ class Model(Network):
verbose=verbose,
steps=steps)
- def predict(self, x, batch_size=None, verbose=0, steps=None):
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
"""Generates output predictions for the input samples.
Computation is done in batches.
@@ -1700,16 +1799,32 @@ class Model(Network):
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
+ - A generator or `keras.utils.Sequence` instance.
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` is your data is in the
- form of symbolic tensors, dataset, or dataset iterators
- (since they generate batches).
+ form of symbolic tensors, dataset, dataset iterators,
+ generators, or `keras.utils.Sequence` instances (since they generate
+ batches).
verbose: Verbosity mode, 0 or 1.
steps: Total number of steps (batches of samples)
before declaring the prediction round finished.
Ignored with the default value of `None`.
+ max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
+ input only. Maximum size for the generator queue.
+ If unspecified, `max_queue_size` will default to 10.
+ workers: Integer. Used for generator or `keras.utils.Sequence` input
+ only. Maximum number of processes to spin up when using
+ process-based threading. If unspecified, `workers` will default
+ to 1. If 0, will execute the generator on the main thread.
+ use_multiprocessing: Boolean. Used for generator or
+ `keras.utils.Sequence` input only. If `True`, use process-based
+ threading. If unspecified, `use_multiprocessing` will default to
+ `False`. Note that because this implementation relies on
+ multiprocessing, you should not pass non-picklable arguments to
+ the generator as they can't be passed easily to children processes.
+
Returns:
Numpy array(s) of predictions.
@@ -1720,6 +1835,15 @@ class Model(Network):
or in case a stateful model receives a number of samples
that is not a multiple of the batch size.
"""
+ if data_utils.is_generator_or_sequence(x):
+ return self.predict_generator(
+ x,
+ steps=steps,
+ verbose=verbose,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
# Backwards compatibility.
if batch_size is None and steps is None:
batch_size = 32
@@ -1731,7 +1855,8 @@ class Model(Network):
# `MirroredStrategy`.
if hasattr(self._distribution_strategy, '_prefetch_on_device'):
self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
- distributed_training_utils.validate_inputs(x, None)
+ distributed_training_utils.validate_inputs(
+ x, None, self._distribution_strategy)
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray) and not steps:
steps = distributed_training_utils.get_input_batch_params(
@@ -2071,7 +2196,7 @@ class Model(Network):
Arguments:
generator: Generator yielding tuples (inputs, targets)
or (inputs, targets, sample_weights)
- or an instance of Sequence (keras.utils.Sequence)
+ or an instance of `keras.utils.Sequence`
object in order to avoid duplicate data
when using multiprocessing.
steps: Total number of steps (batches of samples)
@@ -2135,9 +2260,8 @@ class Model(Network):
Arguments:
generator: Generator yielding batches of input samples
- or an instance of Sequence (keras.utils.Sequence)
- object in order to avoid duplicate data
- when using multiprocessing.
+ or an instance of `keras.utils.Sequence` object in order to
+ avoid duplicate data when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
Optional for `Sequence`: if unspecified, will use
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 53291c3956..8b434ca444 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
@@ -232,8 +233,6 @@ def _experimental_fit_loop(
"""
current_strategy = model._distribution_strategy
- # TODO(priyag): Add validation that shapes are fully defined for TPU case.
-
K.get_session().run(current_strategy.initialize())
def _per_device_train_function(model):
@@ -292,11 +291,16 @@ def _experimental_fit_loop(
for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors):
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ if steps_per_epoch is None:
+ raise ValueError('steps_per_epoch should be specified in the fit call.')
+ steps_per_run_var = K.variable(
+ value=min(steps_per_epoch, current_strategy.steps_per_run),
+ dtype='int32',
+ name='steps_per_run_var')
+
with current_strategy.scope():
- # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on
- # steps_per_epoch and number of epochs.
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=current_strategy.steps_per_run,
+ step_fn, iterator, iterations=steps_per_run_var,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -308,14 +312,6 @@ def _experimental_fit_loop(
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
-
- assert steps_per_epoch is not None
-
- # TODO(sourabhbajaj): Convert this into a proper validation function
- if callbacks:
- raise NotImplementedError(
- 'Callbacks are not supported with TPUStrategy right now.')
-
callbacks = cbks.configure_callbacks(
callbacks,
model,
@@ -326,17 +322,26 @@ def _experimental_fit_loop(
steps_per_epoch=steps_per_epoch,
verbose=verbose)
# TODO(priyag, sourabhbajaj): Add callbacks support for per step callback
- # TODO(priyag, sourabhbajaj): Fix the number of steps run with steps_per_run
# TODO(priyag, sourabhbajaj): Add validation.
+
+ # Calculate the steps each time on the device.
+ steps_to_run = [current_strategy.steps_per_run] * (
+ steps_per_epoch // current_strategy.steps_per_run)
+ if steps_per_epoch % current_strategy.steps_per_run:
+ steps_to_run.append(steps_per_epoch % current_strategy.steps_per_run)
+
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
callbacks.on_epoch_begin(epoch)
epoch_logs = {}
- for step_index in range(0, steps_per_epoch, current_strategy.steps_per_run):
- # TODO(sourabhbajaj): Replace size with a combination of steps_per_run
- # and batch_size
- batch_logs = {'batch': step_index, 'size': 1}
+ step_index = 0
+ prev_step_count = None
+ for step_count in steps_to_run:
+ batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
+ if prev_step_count is None or step_count != prev_step_count:
+ steps_per_run_var.load(step_count, K.get_session())
+ prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
except errors.OutOfRangeError:
@@ -349,6 +354,7 @@ def _experimental_fit_loop(
batch_logs.update(outputs)
callbacks.on_batch_end(step_index, batch_logs)
+ step_index = step_index + step_count
if callbacks.model.stop_training:
break
@@ -742,8 +748,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
for name, tensor in zip(model.output_names, model.outputs):
# TODO(priyag): This is a workaround as we do not know the batch dimension
# of the model's output at this point.
- tensor.shape.dims = [batch_dimension] + tensor.shape.dims[1:]
- initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
+ shape = tensor_shape.TensorShape(tensor.shape.dims)
+ shape.dims = [batch_dimension] + shape.dims[1:]
+ initial_loop_values[name] = array_ops.zeros(shape, tensor.dtype)
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index 939a7f2356..fb71bf2596 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -186,7 +186,7 @@ def iterator_fit_loop(model,
# make sure either x,y or x,y,sample_weights is provided
if (not isinstance(inputs.output_shapes, (list, tuple)) or
len(inputs.output_shapes) not in (2, 3)):
- raise ValueError('Please provide either inputs and targets'
+ raise ValueError('Please provide either inputs and targets '
'or inputs, targets, and sample_weights')
for step_index in range(steps_per_epoch):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 8938333b1a..30be4131a4 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1322,6 +1322,57 @@ class TestGeneratorMethods(test.TestCase):
workers=0,
use_multiprocessing=False)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_generator_input_to_fit_eval_predict(self):
+ val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ def custom_generator():
+ while True:
+ yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ inputs = keras.layers.Input(shape=(10,))
+ x = keras.layers.Dense(10, activation='relu')(inputs)
+ outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+ model = keras.Model(inputs, outputs)
+
+ model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+ model.fit(
+ custom_generator(),
+ steps_per_epoch=2,
+ validation_data=val_data,
+ epochs=2)
+ model.evaluate(custom_generator(), steps=2)
+ model.predict(custom_generator(), steps=2)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_sequence_input_to_fit_eval_predict(self):
+ val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ class CustomSequence(keras.utils.Sequence):
+
+ def __getitem__(self, idx):
+ return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
+
+ def __len__(self):
+ return 2
+
+ inputs = keras.layers.Input(shape=(10,))
+ x = keras.layers.Dense(10, activation='relu')(inputs)
+ outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+ model = keras.Model(inputs, outputs)
+
+ model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+ model.fit(CustomSequence(), validation_data=val_data, epochs=2)
+ model.evaluate(CustomSequence())
+ model.predict(CustomSequence())
+
+ with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
+ model.fit(CustomSequence(), y=np.ones([10, 1]))
+
+ with self.assertRaisesRegexp(ValueError,
+ '`sample_weight` argument is not supported'):
+ model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
+
class TestTrainingUtils(test.TestCase):
@@ -2205,7 +2256,26 @@ class TestTrainingWithMetrics(test.TestCase):
'dense_binary_accuracy', 'dropout_mean_squared_error',
'dropout_binary_accuracy'
]
+ reference_stateful_metric_names = [
+ 'dense_binary_accuracy', 'dropout_binary_accuracy'
+ ]
+ self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
+
+ # Verify that model metric names are not altered during training.
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5)
self.assertEqual(reference_metric_names, model.metrics_names)
+ self.assertEqual(reference_stateful_metric_names,
+ model.stateful_metric_names)
@tf_test_util.run_in_graph_and_eager_modes
def test_metrics_correctness(self):
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 898e9223cb..9c303f4bed 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from collections import OrderedDict
import copy
import math
@@ -484,29 +485,36 @@ def check_loss_and_target_compatibility(targets, loss_fns, output_shapes):
'as the output.')
-def collect_metrics(metrics, output_names):
- """Maps metric functions to model outputs.
+def collect_per_output_metric_info(metrics,
+ output_names,
+ output_shapes,
+ loss_fns,
+ sample_weights=None):
+ """Maps metric names and functions to model outputs.
Arguments:
metrics: a list or dict of metric functions.
output_names: a list of the names (strings) of model outputs.
+ output_shapes: a list of the shapes (strings) of model outputs.
+ loss_fns: a list of the loss functions corresponding to the model outputs.
+ sample_weights: a list of weights to be applied on the model outputs.
Returns:
- A list (one entry per model output) of lists of metric functions.
+ A list (one entry per model output) of dicts.
For instance, if the model has 2 outputs, and for the first output
we want to compute "binary_accuracy" and "binary_crossentropy",
and just "binary_accuracy" for the second output,
- the list would look like:
- `[[binary_accuracy, binary_crossentropy], [binary_accuracy]]`
+ the list would look like: `[[('acc', binary_accuracy()),
+ ('ce', binary_crossentropy())], [('acc', binary_accuracy())]]`
Raises:
TypeError: if an incorrect type is passed for the `metrics` argument.
"""
if not metrics:
- return [[] for _ in output_names]
+ return [{} for _ in output_names]
if isinstance(metrics, list):
# we then apply all metrics to all outputs.
- return [copy.copy(metrics) for _ in output_names]
+ nested_metrics = [copy.copy(metrics) for _ in output_names]
elif isinstance(metrics, dict):
nested_metrics = []
for name in output_names:
@@ -514,11 +522,24 @@ def collect_metrics(metrics, output_names):
if not isinstance(output_metrics, list):
output_metrics = [output_metrics]
nested_metrics.append(output_metrics)
- return nested_metrics
else:
raise TypeError('Type of `metrics` argument not understood. '
'Expected a list or dictionary, found: ' + str(metrics))
+ per_output_metrics = []
+ for i, metrics in enumerate(nested_metrics):
+ metrics_dict = OrderedDict()
+ for metric in metrics:
+ weighted = False if (sample_weights is None) else (
+ sample_weights[i] is not None)
+ metric_name = get_metric_name(metric, weighted)
+ metric_fn = get_metric_function(
+ metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
+ metrics_dict[metric_name] = metric_fn
+ per_output_metrics.append(metrics_dict)
+
+ return per_output_metrics
+
def batch_shuffle(index_array, batch_size):
"""Shuffles an array in a batch-wise fashion.
@@ -729,6 +750,33 @@ def has_tensors(ls):
return tensor_util.is_tensor(ls)
+def get_metric_name(metric, weighted=False):
+ """Returns the name corresponding to the given metric input.
+
+ Arguments:
+ metric: Metric function name or reference.
+ weighted: Boolean indicating if the given metric is weighted.
+
+ Returns:
+ The metric name.
+ """
+ metric_name_prefix = 'weighted_' if weighted else ''
+ if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):
+ if metric in ('accuracy', 'acc'):
+ suffix = 'acc'
+ elif metric in ('crossentropy', 'ce'):
+ suffix = 'ce'
+ else:
+ metric_fn = metrics_module.get(metric)
+ # Get metric name as string
+ if hasattr(metric_fn, 'name'):
+ suffix = metric_fn.name
+ else:
+ suffix = metric_fn.__name__
+ metric_name = metric_name_prefix + suffix
+ return metric_name
+
+
def get_metric_function(metric, output_shape=None, loss_fn=None):
"""Returns the metric function corresponding to the given metric input.
@@ -797,6 +845,18 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None):
'Received: x=%s, validation_split=%f' % (x, validation_split))
+def check_generator_arguments(y=None, sample_weight=None):
+ """Validates arguments passed when using a generator."""
+ if y is not None:
+ raise ValueError('`y` argument is not supported when data is'
+ 'a generator or Sequence instance. Instead pass targets'
+ ' as the second element of the generator.')
+ if sample_weight is not None:
+ raise ValueError('`sample_weight` argument is not supported when data is'
+ 'a generator or Sequence instance. Instead pass sample'
+ ' weights as the third element of the generator.')
+
+
def check_steps_argument(input_data, steps, steps_name):
"""Validates `steps` argument based on input data's type.
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index 61ab69c16f..a2385dfdbb 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@@ -268,7 +267,7 @@ class Softmax(Layer):
self.axis = axis
def call(self, inputs):
- return activations.softmax(inputs, axis=self.axis)
+ return K.softmax(inputs, axis=self.axis)
def get_config(self):
config = {'axis': self.axis}
@@ -315,18 +314,19 @@ class ReLU(Layer):
'cannot be negative value: ' + str(negative_slope))
self.support_masking = True
- self.max_value = K.cast_to_floatx(max_value)
+ if max_value is not None:
+ max_value = K.cast_to_floatx(max_value)
+ self.max_value = max_value
self.negative_slope = K.cast_to_floatx(negative_slope)
self.threshold = K.cast_to_floatx(threshold)
def call(self, inputs):
# alpha is used for leaky relu slope in activations instead of
# negative_slope.
- return activations.relu(
- inputs,
- alpha=self.negative_slope,
- max_value=self.max_value,
- threshold=self.threshold)
+ return K.relu(inputs,
+ alpha=self.negative_slope,
+ max_value=self.max_value,
+ threshold=self.threshold)
def get_config(self):
config = {
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index b020b6e730..c41087be0a 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -67,6 +67,14 @@ class AdvancedActivationsTest(test.TestCase):
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': 10},
input_shape=(2, 3, 4))
+ x = keras.backend.ones((3, 4))
+ # Test that we use `leaky_relu` when appropriate in graph mode.
+ self.assertTrue(
+ 'LeakyRelu' in keras.layers.ReLU(negative_slope=0.2)(x).name)
+ # Test that we use `relu` when appropriate in graph mode.
+ self.assertTrue('Relu' in keras.layers.ReLU()(x).name)
+ # Test that we use `relu6` when appropriate in graph mode.
+ self.assertTrue('Relu6' in keras.layers.ReLU(max_value=6)(x).name)
def test_relu_with_invalid_arg(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index a57ac121ed..d00def07bb 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -64,7 +64,7 @@ class Conv(Layer):
specifying the stride length of the convolution.
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
+ padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
@@ -126,6 +126,10 @@ class Conv(Layer):
kernel_size, rank, 'kernel_size')
self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
self.padding = conv_utils.normalize_padding(padding)
+ if (self.padding == 'causal' and not isinstance(self,
+ (Conv1D, SeparableConv1D))):
+ raise ValueError('Causal padding is only supported for `Conv1D`'
+ 'and ``SeparableConv1D`.')
self.data_format = conv_utils.normalize_data_format(data_format)
self.dilation_rate = conv_utils.normalize_tuple(
dilation_rate, rank, 'dilation_rate')
@@ -172,12 +176,16 @@ class Conv(Layer):
self.bias = None
self.input_spec = InputSpec(ndim=self.rank + 2,
axes={channel_axis: input_dim})
+ if self.padding == 'causal':
+ op_padding = 'valid'
+ else:
+ op_padding = self.padding
self._convolution_op = nn_ops.Convolution(
input_shape,
filter_shape=self.kernel.get_shape(),
dilation_rate=self.dilation_rate,
strides=self.strides,
- padding=self.padding.upper(),
+ padding=op_padding.upper(),
data_format=conv_utils.convert_data_format(self.data_format,
self.rank + 2))
self.built = True
@@ -264,6 +272,15 @@ class Conv(Layer):
base_config = super(Conv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+ def _compute_causal_padding(self):
+ """Calculates padding for 'causal' option for 1-d conv layers."""
+ left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)
+ if self.data_format == 'channels_last':
+ causal_padding = [[0, 0], [left_pad, 0], [0, 0]]
+ else:
+ causal_padding = [[0, 0], [0, 0], [left_pad, 0]]
+ return causal_padding
+
@tf_export('keras.layers.Conv1D', 'keras.layers.Convolution1D')
class Conv1D(Conv):
@@ -361,6 +378,11 @@ class Conv1D(Conv):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ def call(self, inputs):
+ if self.padding == 'causal':
+ inputs = array_ops.pad(inputs, self._compute_causal_padding())
+ return super(Conv1D, self).call(inputs)
+
@tf_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
class Conv2D(Conv):
@@ -1261,31 +1283,44 @@ class SeparableConv(Conv):
def get_config(self):
config = {
- 'filters': self.filters,
- 'kernel_size': self.kernel_size,
- 'strides': self.strides,
- 'padding': self.padding,
- 'data_format': self.data_format,
- 'dilation_rate': self.dilation_rate,
- 'activation': activations.serialize(self.activation),
- 'use_bias': self.use_bias,
+ 'filters':
+ self.filters,
+ 'kernel_size':
+ self.kernel_size,
+ 'strides':
+ self.strides,
+ 'padding':
+ self.padding,
+ 'data_format':
+ self.data_format,
+ 'depth_multiplier':
+ self.depth_multiplier,
+ 'dilation_rate':
+ self.dilation_rate,
+ 'activation':
+ activations.serialize(self.activation),
+ 'use_bias':
+ self.use_bias,
'depthwise_initializer':
initializers.serialize(self.depthwise_initializer),
'pointwise_initializer':
initializers.serialize(self.pointwise_initializer),
- 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'bias_initializer':
+ initializers.serialize(self.bias_initializer),
'depthwise_regularizer':
regularizers.serialize(self.depthwise_regularizer),
'pointwise_regularizer':
regularizers.serialize(self.pointwise_regularizer),
- 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer':
+ regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
'depthwise_constraint':
constraints.serialize(self.depthwise_constraint),
'pointwise_constraint':
constraints.serialize(self.pointwise_constraint),
- 'bias_constraint': constraints.serialize(self.bias_constraint)
+ 'bias_constraint':
+ constraints.serialize(self.bias_constraint)
}
base_config = super(SeparableConv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -1311,7 +1346,7 @@ class SeparableConv1D(SeparableConv):
of the convolution.
Specifying any `stride` value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
- padding: One of `"valid"` or `"same"` (case-insensitive).
+ padding: One of `"valid"`, `"same"`, or `"causal"` (case-insensitive).
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
@@ -1397,6 +1432,8 @@ class SeparableConv1D(SeparableConv):
**kwargs)
def call(self, inputs):
+ if self.padding == 'causal':
+ inputs = array_ops.pad(inputs, self._compute_causal_padding())
if self.data_format == 'channels_last':
strides = (1,) + self.strides * 2 + (1,)
spatial_start_dim = 1
@@ -1411,12 +1448,16 @@ class SeparableConv1D(SeparableConv):
pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
dilation_rate = (1,) + self.dilation_rate
+ if self.padding == 'causal':
+ op_padding = 'valid'
+ else:
+ op_padding = self.padding
outputs = nn.separable_conv2d(
inputs,
depthwise_kernel,
pointwise_kernel,
strides=strides,
- padding=self.padding.upper(),
+ padding=op_padding.upper(),
rate=dilation_rate,
data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index f904744422..2d3d38a5ce 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -52,7 +52,7 @@ class Convolution1DTest(test.TestCase):
'kernel_size': 3,
}
- self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
self._run_test(kwargs, 'strides', [2])
self._run_test(kwargs, 'dilation_rate', [2])
@@ -329,7 +329,7 @@ class SeparableConv1DTest(test.TestCase):
'kernel_size': 3,
}
- self._run_test(kwargs, 'padding', ['valid', 'same'])
+ self._run_test(kwargs, 'padding', ['valid', 'same', 'causal'])
self._run_test(kwargs, 'strides', [2])
self._run_test(kwargs, 'dilation_rate', [2])
self._run_test(kwargs, 'depth_multiplier', [2])
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 629a9ec9a1..c6df5f2e26 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@@ -117,12 +119,27 @@ class Embedding(Layer):
@tf_utils.shape_type_conversion
def build(self, input_shape):
- self.embeddings = self.add_weight(
- shape=(self.input_dim, self.output_dim),
- initializer=self.embeddings_initializer,
- name='embeddings',
- regularizer=self.embeddings_regularizer,
- constraint=self.embeddings_constraint)
+ # Note: most sparse optimizers do not have GPU kernels defined. When
+ # building graphs, the placement algorithm is able to place variables on CPU
+ # since it knows all kernels using the variable only exist on CPU.
+ # When eager execution is enabled, the placement decision has to be made
+ # right now. Checking for the presence of GPUs to avoid complicating the
+ # TPU codepaths which can handle sparse optimizers.
+ if context.executing_eagerly() and context.context().num_gpus():
+ with ops.device('cpu:0'):
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
+ else:
+ self.embeddings = self.add_weight(
+ shape=(self.input_dim, self.output_dim),
+ initializer=self.embeddings_initializer,
+ name='embeddings',
+ regularizer=self.embeddings_regularizer,
+ constraint=self.embeddings_constraint)
self.built = True
def compute_mask(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index cab176ee34..2e42e403aa 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
class EmbeddingTest(test.TestCase):
@@ -78,6 +80,17 @@ class EmbeddingTest(test.TestCase):
outputs = keras.backend.eval(layer(inputs))
self.assertAllClose(outputs, [[[1, 1], [2, 2], [1, 1]]])
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_eager_gpu_cpu(self):
+ l = keras.layers.Embedding(output_dim=2, input_dim=2)
+ l.build((None, 2))
+ inputs = keras.backend.constant([[0, 1, 0]], dtype='int32')
+ with backprop.GradientTape() as tape:
+ output = l(inputs)
+ gs = tape.gradient(output, l.weights)
+ opt = adagrad.AdagradOptimizer(0.1)
+ opt.apply_gradients(zip(gs, l.weights))
+ self.assertAllEqual(len(gs), 1)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index a3861e44d5..b9e90095e4 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -530,7 +530,9 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
- def test_stacked_rnn_dropout(self):
+ def DISABLED_test_stacked_rnn_dropout(self):
+ # Temporarily disabled test due an occasional Grappler segfault.
+ # See b/115523414
cells = [keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
layer = keras.layers.RNN(cells)
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 473d8cd95b..e64241e5cf 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -199,7 +199,6 @@ def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
# squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
y_true, y_pred)
- y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if sample_weight is None:
return y_pred, y_true, None
@@ -342,19 +341,14 @@ class Metric(Layer):
# weak reference. This is to remove reference cycle that is created here.
# This is not an issue in python versions > 3.
if context.executing_eagerly():
- update_state = weakmethod(obj.update_state)
- else:
- update_state = function.defun(obj.update_state)
+ obj.update_state = weakmethod(obj.update_state)
obj.update_state = weakmethod(
- types.MethodType(update_state_wrapper(update_state), obj))
+ types.MethodType(update_state_wrapper(obj.update_state), obj))
result = weakmethod(obj.result)
obj.result = weakmethod(types.MethodType(result_wrapper(result), obj))
else:
- # Converting update_state_fn() into a graph function, so that
- # we can return a single op that performs all of the variable updates.
- defuned_update_state_fn = function.defun(obj.update_state)
obj.update_state = types.MethodType(
- update_state_wrapper(defuned_update_state_fn), obj)
+ update_state_wrapper(obj.update_state), obj)
obj.result = types.MethodType(result_wrapper(obj.result), obj)
return obj
@@ -475,6 +469,9 @@ class Mean(Metric):
Args:
values: Per-example value.
sample_weight: Optional weighting of each example. Defaults to 1.
+
+ Returns:
+ Update op.
"""
values = math_ops.cast(values, self._dtype)
if sample_weight is None:
@@ -501,8 +498,9 @@ class Mean(Metric):
values = math_ops.reduce_sum(values)
# Update state variables
- state_ops.assign_add(self.total, values)
- state_ops.assign_add(self.count, num_values)
+ update_total_op = state_ops.assign_add(self.total, values)
+ update_count_op = state_ops.assign_add(self.count, num_values)
+ return control_flow_ops.group(update_total_op, update_count_op)
def result(self):
return safe_div(self.total, self.count)
@@ -536,6 +534,9 @@ class MeanMetricWrapper(Mean):
sample_weight: Optional weighting of each example. Defaults to 1. Can be
a `Tensor` whose rank is either 0, or the same rank as `y_true`,
and must be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
"""
y_true = math_ops.cast(y_true, self._dtype)
y_pred = math_ops.cast(y_pred, self._dtype)
@@ -543,7 +544,7 @@ class MeanMetricWrapper(Mean):
y_pred, y_true, sample_weight)
matches = self._fn(y_true, y_pred, **self._fn_kwargs)
- super(MeanMetricWrapper, self).update_state(
+ return super(MeanMetricWrapper, self).update_state(
matches, sample_weight=sample_weight)
def get_config(self):
@@ -600,6 +601,23 @@ class CategoricalAccuracy(MeanMetricWrapper):
categorical_accuracy, name, dtype=dtype)
+class SparseCategoricalAccuracy(MeanMetricWrapper):
+ """Calculates how often predictions matches integer labels.
+
+ This metric creates two local variables, `total` and `count` that are used to
+ compute the frequency with which `y_pred` matches `y_true`. This frequency is
+ ultimately returned as `sparse categorical accuracy`: an idempotent operation
+ that simply divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+ """
+
+ def __init__(self, name='sparse_categorical_accuracy', dtype=None):
+ super(SparseCategoricalAccuracy, self).__init__(
+ sparse_categorical_accuracy, name, dtype=dtype)
+
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred, threshold=0.5):
threshold = math_ops.cast(threshold, y_pred.dtype)
@@ -615,6 +633,7 @@ def categorical_accuracy(y_true, y_pred):
K.floatx())
+@tf_export('keras.metrics.sparse_categorical_accuracy')
def sparse_categorical_accuracy(y_true, y_pred):
y_true = math_ops.reduce_max(y_true, axis=-1)
y_pred = math_ops.argmax(y_pred, axis=-1)
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 71c1987cee..3a1b00041f 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -463,7 +463,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = SimpleTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -481,7 +481,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = MultiIOTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -501,7 +501,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 10
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = SimpleTestModel(num_classes=num_classes, use_dp=True, use_bn=True)
model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
@@ -521,7 +521,7 @@ class ModelSubclassingTest(test.TestCase):
num_samples = 1000
input_dim = 50
- with self.test_session():
+ with self.cached_session():
model = MultiIOTestModel(num_classes=num_classes,
use_dp=True,
use_bn=True)
@@ -610,7 +610,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel1()
x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -631,7 +631,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel2()
x = array_ops.ones(shape=[100, 784], dtype='float32')
@@ -655,7 +655,7 @@ class ModelSubclassingTest(test.TestCase):
def call(self, x):
return self.bn(self.fc(x))
- with self.test_session():
+ with self.cached_session():
model = TestModel3()
x = array_ops.ones(shape=[100, 784], dtype='float32')
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 9a68fc0e35..9664f09fff 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -18,10 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import gc
+import weakref
+
import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -85,23 +89,23 @@ def _test_optimizer(optimizer, target=0.75):
class KerasOptimizersTest(test.TestCase):
def test_sgd(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
nesterov=True))
def test_rmsprop(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.RMSprop())
_test_optimizer(keras.optimizers.RMSprop(decay=1e-3))
def test_adagrad(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adagrad())
_test_optimizer(keras.optimizers.Adagrad(decay=1e-3))
def test_adadelta(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adadelta(), target=0.6)
# Accuracy seems dependent on the initialization. Even adding tf.Print
# nodes in the graph seemed to affect the initialization seed, and hence
@@ -109,28 +113,28 @@ class KerasOptimizersTest(test.TestCase):
_test_optimizer(keras.optimizers.Adadelta(decay=1e-3), target=0.4)
def test_adam(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adam())
_test_optimizer(keras.optimizers.Adam(decay=1e-3))
_test_optimizer(keras.optimizers.Adam(amsgrad=True))
def test_adamax(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Adamax())
_test_optimizer(keras.optimizers.Adamax(decay=1e-3))
def test_nadam(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.Nadam())
def test_clipnorm(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
clipnorm=0.5))
def test_clipvalue(self):
- with self.test_session():
+ with self.cached_session():
_test_optimizer(keras.optimizers.SGD(lr=0.01,
momentum=0.9,
clipvalue=0.5))
@@ -156,9 +160,22 @@ class KerasOptimizersTest(test.TestCase):
with self.assertRaises(NotImplementedError):
optimizer.from_config(None)
+ def test_optimizer_garbage_collection(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
+ keras.backend.track_tf_optimizer(optimizer)
+ optimizer_weak = weakref.ref(optimizer)
+ graph_weak = weakref.ref(graph)
+ del graph, optimizer
+ gc.collect()
+ # Check that the weak references are dead now.
+ self.assertIs(graph_weak(), None)
+ self.assertIs(optimizer_weak(), None)
+
@test_util.run_in_graph_and_eager_modes
def test_tfoptimizer_iterations(self):
- with self.test_session():
+ with self.cached_session():
optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
model = keras.models.Sequential()
model.add(keras.layers.Dense(
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 58405c550b..501b50ba5f 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -29,7 +29,8 @@ from tensorflow.python.util import tf_inspect
def get_test_data(train_samples,
test_samples,
input_shape,
- num_classes):
+ num_classes,
+ random_seed=None):
"""Generates test data to train a model on.
Arguments:
@@ -37,10 +38,13 @@ def get_test_data(train_samples,
test_samples: Integer, how many test samples to generate.
input_shape: Tuple of integers, shape of the inputs.
num_classes: Integer, number of classes for the data and targets.
+ random_seed: Integer, random seed used by numpy to generate data.
Returns:
A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
"""
+ if random_seed is not None:
+ np.random.seed(random_seed)
num_sample = train_samples + test_samples
templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
y = np.random.randint(0, num_classes, size=(num_sample,))
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 3a176c3316..8ebca1418d 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -93,7 +93,7 @@ def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
Arguments:
input_length: integer.
filter_size: integer.
- padding: one of "same", "valid", "full".
+ padding: one of "same", "valid", "full", "causal"
stride: integer.
dilation: dilation rate, integer.
@@ -102,9 +102,9 @@ def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
"""
if input_length is None:
return None
- assert padding in {'same', 'valid', 'full'}
+ assert padding in {'same', 'valid', 'full', 'causal'}
dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
- if padding == 'same':
+ if padding in ['same', 'causal']:
output_length = input_length
elif padding == 'valid':
output_length = input_length - dilated_filter_size + 1
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index d93a7b6afc..b736daa46d 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -40,6 +40,7 @@ from six.moves.urllib.error import URLError
from six.moves.urllib.request import urlopen
from tensorflow.python.keras.utils.generic_utils import Progbar
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -93,6 +94,11 @@ else:
from six.moves.urllib.request import urlretrieve
+def is_generator_or_sequence(x):
+ """Check if `x` is a Keras generator type."""
+ return tf_inspect.isgenerator(x) or isinstance(x, Sequence)
+
+
def _extract_archive(file_path, path='.', archive_format='auto'):
"""Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats.
@@ -551,7 +557,7 @@ class OrderedEnqueuer(SequenceEnqueuer):
self.executor_fn = lambda seqs: multiprocessing.Pool( # pylint: disable=g-long-lambda
workers, initializer=init_pool, initargs=(seqs,))
else:
- # We do not need the init since it's threads.
+ # We do not need the init since it's threads.
self.executor_fn = lambda _: ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index c7e94998b4..3d0351a11f 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -48,7 +48,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(hidden_dim,
input_shape=(input_dim,)))
@@ -78,7 +78,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_a = keras.Input((input_dim_a,))
input_b = keras.Input((input_dim_b,))
a = keras.layers.Dense(hidden_dim)(input_a)
@@ -105,7 +105,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=2):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (1000, 10)
model = keras.models.Sequential()
model.add(keras.layers.Dense(10,
@@ -144,7 +144,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
input_shape = (num_samples,) + shape
x_train = np.random.randint(0, 255, input_shape)
y_train = np.random.randint(0, num_classes, (input_shape[0],))
@@ -186,7 +186,7 @@ class TestMultiGPUModel(test.TestCase):
if not check_if_compatible_devices(gpus=gpus):
return
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((4, 3))
init_state = keras.Input((3,))
outputs = keras.layers.SimpleRNN(
diff --git a/tensorflow/python/keras/wrappers/scikit_learn_test.py b/tensorflow/python/keras/wrappers/scikit_learn_test.py
index c322efdedf..f904290803 100644
--- a/tensorflow/python/keras/wrappers/scikit_learn_test.py
+++ b/tensorflow/python/keras/wrappers/scikit_learn_test.py
@@ -102,7 +102,7 @@ def assert_regression_works(reg):
class ScikitLearnAPIWrapperTest(test.TestCase):
def test_classify_build_fn(self):
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=build_fn_clf,
hidden_dim=HIDDEN_DIM,
@@ -118,7 +118,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = keras.wrappers.scikit_learn.KerasClassifier(
build_fn=ClassBuildFnClf(),
hidden_dim=HIDDEN_DIM,
@@ -134,7 +134,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_clf(hidden_dim)
- with self.test_session():
+ with self.cached_session():
clf = InheritClassBuildFnClf(
build_fn=None,
hidden_dim=HIDDEN_DIM,
@@ -144,7 +144,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
assert_classification_works(clf)
def test_regression_build_fn(self):
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=build_fn_reg,
hidden_dim=HIDDEN_DIM,
@@ -160,7 +160,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = keras.wrappers.scikit_learn.KerasRegressor(
build_fn=ClassBuildFnReg(),
hidden_dim=HIDDEN_DIM,
@@ -176,7 +176,7 @@ class ScikitLearnAPIWrapperTest(test.TestCase):
def __call__(self, hidden_dim):
return build_fn_reg(hidden_dim)
- with self.test_session():
+ with self.cached_session():
reg = InheritClassBuildFnReg(
build_fn=None,
hidden_dim=HIDDEN_DIM,
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0403211d92..5183e4d30c 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -286,7 +286,10 @@ tf_py_test(
srcs = ["decode_csv_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
+ "//tensorflow/python/eager:context",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
],
)
@@ -535,6 +538,21 @@ tf_py_test(
)
tf_py_test(
+ name = "logging_ops_logging_level_test",
+ size = "small",
+ srcs = ["logging_ops_logging_level_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:logging_ops",
+ ],
+ tags = [
+ "no_windows",
+ ],
+)
+
+tf_py_test(
name = "logging_ops_test",
size = "small",
srcs = ["logging_ops_test.py"],
@@ -958,6 +976,19 @@ tf_py_test(
)
tf_py_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+tf_py_test(
name = "string_join_op_test",
size = "small",
srcs = ["string_join_op_test.py"],
@@ -1011,6 +1042,7 @@ tf_py_test(
size = "small",
srcs = ["substr_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -1631,6 +1663,18 @@ cuda_py_test(
)
cuda_py_test(
+ name = "extract_volume_patches_op_test",
+ size = "small",
+ srcs = ["extract_volume_patches_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+cuda_py_test(
name = "functional_ops_test",
size = "small",
srcs = ["functional_ops_test.py"],
@@ -2795,6 +2839,46 @@ cuda_py_test(
)
cuda_py_test(
+ name = "cwise_ops_binary_test",
+ size = "medium",
+ srcs = ["cwise_ops_binary_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_grad",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+ shard_count = 50,
+)
+
+cuda_py_test(
+ name = "cwise_ops_unary_test",
+ size = "medium",
+ srcs = ["cwise_ops_unary_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_grad",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+ shard_count = 50,
+)
+
+cuda_py_test(
name = "embedding_ops_test",
size = "medium",
srcs = ["embedding_ops_test.py"],
@@ -3160,3 +3244,27 @@ tf_py_test(
grpc_enabled = True,
tags = ["no_gpu"], # TODO(b/111656070)
)
+
+# TODO(b/116053459): Replace with cuda_py_test.
+tf_py_test(
+ name = "while_v2_test",
+ size = "medium",
+ srcs = ["while_v2_test.py"],
+ additional_deps = [
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients_impl",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:tf_optimizer",
+ "//tensorflow/python:while_v2",
+ ],
+ grpc_enabled = True,
+ tags = ["no_gpu"], # TODO(b/116053459)
+)
diff --git a/tensorflow/python/kernel_tests/accumulate_n_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py
index b793906fac..0bc5268f38 100644
--- a/tensorflow/python/kernel_tests/accumulate_n_test.py
+++ b/tensorflow/python/kernel_tests/accumulate_n_test.py
@@ -76,7 +76,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
# Putting them here so that everything that exercises AccumulateNV2 is in
# one place and the default build runs all unit tests.
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
random_arrays = [
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
]
@@ -91,27 +91,27 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
self.assertAllClose(np_val, tf_val.eval())
def testZeroArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
tf_val = math_ops.accumulate_n([])
tf_val.eval()
def testWrongShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(0.2)
b = variables.Variable(0.1)
math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[]
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(np.array([0.1, 0.2]))
b = variables.Variable(np.array([[0.3], [0.4]]))
math_ops.accumulate_n([a, b])
def testWrongType(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
b = variables.Variable(0.1, dtype=np.float32)
@@ -119,7 +119,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase):
def testWrongTypeOneInput(self):
# Scenario that used to trigger a bug, even when testWrongType() worked
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
math_ops.accumulate_n([a], tensor_dtype=np.int32)
diff --git a/tensorflow/python/kernel_tests/ackermann_test.py b/tensorflow/python/kernel_tests/ackermann_test.py
index 5e0d87c783..d267e49752 100644
--- a/tensorflow/python/kernel_tests/ackermann_test.py
+++ b/tensorflow/python/kernel_tests/ackermann_test.py
@@ -34,7 +34,7 @@ class AckermannTest(test.TestCase):
self.assertEqual(len(ackermann.OP_LIST.op), 1)
self.assertEqual(ackermann.OP_LIST.op[0].name, 'Ackermann')
- with self.test_session():
+ with self.cached_session():
self.assertEqual(ackermann.ackermann().eval(), b'A(m, 0) == A(m-1, 1)')
diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py
index 1202c463e8..127d14c250 100644
--- a/tensorflow/python/kernel_tests/argmax_op_test.py
+++ b/tensorflow/python/kernel_tests/argmax_op_test.py
@@ -104,20 +104,20 @@ class ArgMaxTest(test.TestCase):
self._testDim(np.int64)
def testEmpty(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
with self.assertRaisesOpError(
r"Reduction axis 0 is empty in shape \[0\]"):
op([], 0).eval()
def testDefaultAxis(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
ans = op([1]).eval()
self.assertAllEqual(ans, 0)
def testOutputEmpty(self):
- with self.test_session():
+ with self.cached_session():
for op in math_ops.argmin, math_ops.argmax:
ret = op(array_ops.zeros(shape=[1, 0, 2]), axis=-1).eval()
self.assertEqual(ret.shape, (1, 0))
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index a164682227..2fe85839d0 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -50,7 +50,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testNonBatchMatrix(self):
matrix = [[1, 2, 3], [4, 5, 6]] # Shape (2, 3)
expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2)
- with self.test_session():
+ with self.cached_session():
transposed = array_ops.matrix_transpose(matrix)
self.assertEqual((3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
@@ -58,7 +58,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testConjugate(self):
m = [[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 6 + 6j]]
expected_transposed = [[1 - 1j, 4 - 4j], [2 - 2j, 5 - 5j], [3 - 3j, 6 - 6j]]
- with self.test_session():
+ with self.cached_session():
matrix = ops.convert_to_tensor(m)
transposed = array_ops.matrix_transpose(matrix, conjugate=True)
self.assertEqual((3, 2), transposed.get_shape())
@@ -71,7 +71,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_1_t = [[11, 44], [22, 55], [33, 66]]
batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2)
- with self.test_session():
+ with self.cached_session():
transposed = array_ops.matrix_transpose(batch_matrix)
self.assertEqual((2, 3, 2), transposed.get_shape())
self.assertAllEqual(expected_transposed, transposed.eval())
@@ -79,7 +79,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testNonBatchMatrixDynamicallyDefined(self):
matrix = [[1, 2, 3], [4, 5, 6]] # Shape (2, 3)
expected_transposed = [[1, 4], [2, 5], [3, 6]] # Shape (3, 2)
- with self.test_session():
+ with self.cached_session():
matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(matrix_ph)
self.assertAllEqual(
@@ -94,7 +94,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
matrix_1_t = [[11, 44], [22, 55], [33, 66]]
batch_matrix = [matrix_0, matrix_1] # Shape (2, 2, 3)
expected_transposed = [matrix_0_t, matrix_1_t] # Shape (2, 3, 2)
- with self.test_session():
+ with self.cached_session():
batch_matrix_ph = array_ops.placeholder(dtypes.int32)
transposed = array_ops.matrix_transpose(batch_matrix_ph)
self.assertAllEqual(
@@ -105,7 +105,7 @@ class BatchMatrixTransposeTest(test_util.TensorFlowTestCase):
def testTensorWithStaticRankLessThanTwoRaisesBecauseNotAMatrix(self):
vector = [1, 2, 3]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "should be a "):
array_ops.matrix_transpose(vector)
@@ -129,7 +129,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
masked_arr = arr[:, mask]
elif axis == 2:
masked_arr = arr[:, :, mask]
- with self.test_session():
+ with self.cached_session():
masked_tensor = array_ops.boolean_mask(arr, mask, axis=axis)
# Leading dimension size of masked_tensor is always unknown until runtime
@@ -176,7 +176,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
numpy_result = arr[mask]
tf_result = array_ops.boolean_mask(arr, mask)
self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(numpy_result, tf_result.eval())
def testEmptyInput1D(self):
@@ -185,7 +185,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
numpy_result = arr[mask]
tf_result = array_ops.boolean_mask(arr, mask)
self.assertAllEqual(numpy_result.shape[1:], tf_result.get_shape()[1:])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(numpy_result, tf_result.eval())
def testEmptyOutput(self):
@@ -199,7 +199,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testWorksWithDimensionsEqualToNoneDuringGraphBuild(self):
# The rank of the mask tensor must be specified. This is explained
# in the docstring as well.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ph_tensor = array_ops.placeholder(dtypes.int32, shape=None)
ph_mask = array_ops.placeholder(dtypes.bool, shape=[None])
@@ -217,7 +217,7 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testMaskDimensionsSetToNoneRaises(self):
# The rank of the mask tensor must be specified. This is explained
# in the docstring as well.
- with self.test_session():
+ with self.cached_session():
tensor = array_ops.placeholder(dtypes.int32, shape=[None, 2])
mask = array_ops.placeholder(dtypes.bool, shape=None)
with self.assertRaisesRegexp(ValueError, "dimensions must be specified"):
@@ -226,21 +226,21 @@ class BooleanMaskTest(test_util.TensorFlowTestCase):
def testMaskHasMoreDimsThanTensorRaises(self):
mask = [[True, True], [False, False]]
tensor = [1, 2, 3, 4]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "incompatible"):
array_ops.boolean_mask(tensor, mask).eval()
def testMaskIsScalarRaises(self):
mask = True
tensor = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mask.*scalar"):
array_ops.boolean_mask(tensor, mask).eval()
def testMaskShapeDifferentThanFirstPartOfTensorShapeRaises(self):
mask = [True, True, True]
tensor = [[1, 2], [3, 4]]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "incompatible"):
array_ops.boolean_mask(tensor, mask).eval()
@@ -345,7 +345,7 @@ class ReverseV2Test(test_util.TensorFlowTestCase):
def testInvalid(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
axis = array_ops.placeholder(dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of valid range"):
array_ops.reverse_v2(x_np, axis).eval(feed_dict={axis: [-30]})
@@ -954,7 +954,7 @@ class StridedSliceAssignChecker(object):
class SliceAssignTest(test_util.TensorFlowTestCase):
def testInvalidSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
foo = constant_op.constant([1, 2, 3])
with self.assertRaisesRegexp(ValueError, "Sliced assignment"
" is only supported for variables"):
@@ -1000,7 +1000,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(
errors.FailedPreconditionError,
"Attempting to use uninitialized value Variable"):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable([1, 2])
sess.run(v[:].assign([1, 2]))
@@ -1019,7 +1019,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
v = resource_variable_ops.ResourceVariable(init_val)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(v.initializer)
with self.assertRaises(ValueError):
sess.run(v[:].assign(too_large_val))
@@ -1066,12 +1066,12 @@ class ShapeSizeRankTest(test_util.TensorFlowTestCase):
class SequenceMaskTest(test_util.TensorFlowTestCase):
def testExceptions(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "maxlen must be scalar"):
array_ops.sequence_mask([10, 20], [10, 20])
def testOneDimensionalWithMaxlen(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(constant_op.constant([1, 3, 2]), 5)
self.assertAllEqual(res.get_shape(), [3, 5])
self.assertAllEqual(
@@ -1081,7 +1081,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testOneDimensionalDtypeWithoutMaxlen(self):
- with self.test_session():
+ with self.cached_session():
# test dtype and default maxlen:
res = array_ops.sequence_mask(constant_op.constant([0, 1, 4]),
dtype=dtypes.float32)
@@ -1092,7 +1092,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testOneDimensionalWithoutMaxlen(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(
constant_op.constant([0, 1, 4]))
self.assertAllEqual(res.get_shape().as_list(), [3, 4])
@@ -1104,7 +1104,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
@test_util.enable_c_shapes
def testTwoDimensional(self):
- with self.test_session():
+ with self.cached_session():
res = array_ops.sequence_mask(constant_op.constant([[1, 3, 2]]), 5)
self.assertAllEqual(res.get_shape(), [1, 3, 5])
self.assertAllEqual(res.eval(), [[[True, False, False, False, False], [
@@ -1137,7 +1137,7 @@ class SequenceMaskTest(test_util.TensorFlowTestCase):
[[True, False, False, False, False], [True, True, True, False, False],
[True, True, False, False, False]])
- with self.test_session():
+ with self.cached_session():
check_dtypes(dtypes.int32, dtypes.int32)
check_dtypes(dtypes.int32, dtypes.int64)
check_dtypes(dtypes.int64, dtypes.int32)
@@ -1216,7 +1216,7 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
# TODO(b/73086570): Reenable test.
@unittest.skip("Test does not pass internally.")
def testUnravelIndex(self):
- with self.test_session():
+ with self.cached_session():
for dtype in [dtypes.int32, dtypes.int64]:
indices_1 = constant_op.constant(1621, dtype=dtype)
dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
@@ -1237,13 +1237,13 @@ class UnravelIndexTest(test_util.TensorFlowTestCase):
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.constant(10)
guarantee_a = array_ops.guarantee_const(a)
self.assertEqual(10, guarantee_a.eval())
def testVariables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for use_resource in [False, True]:
a = variable_scope.get_variable(
"var_{}".format(use_resource), [],
@@ -1254,7 +1254,7 @@ class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, guarantee_a.eval())
def testResourceRejection(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variable_scope.get_variable(
"resource_var", [],
initializer=init_ops.constant_initializer(10.0),
@@ -1276,5 +1276,203 @@ class SnapshotOpTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [0, 1, 2, 3])
+@test_util.run_all_in_graph_and_eager_modes
+class SortedSearchTest(test_util.TensorFlowTestCase):
+
+ def testUpperBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatHandCoded(self):
+ cdf = np.array([0, .2, .5, .6, .8, 1.], dtype=np.float32)
+ arr = np.array([.04, .99, .53, .58, .31, .01, .79, .8, .21],
+ dtype=np.float32)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.uniform(size=shape).astype(np.float32), axis=(d - 1))
+ arr = np.random.uniform(size=shape).astype(np.float32) * dim_size
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundFloatUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.uniform(size=[batch_size, size_search_array]).astype(
+ np.float32),
+ axis=1)
+ arr = np.random.uniform(size=[batch_size, size_values]).astype(
+ np.float32) * size_search_array
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="right")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testUpperBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="right"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="right")
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntHandCoded(self):
+ cdf = np.array([0, 20, 50, 60, 80, 100], dtype=np.int64)
+ arr = np.array([4, 99, 53, 58, 31, 1, 79, 8, 21], dtype=np.int64)
+ result = np.searchsorted(cdf, arr, side="left")
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntRandomNd(self):
+ dim_size = 7
+ for d in range(1, 5):
+ shape = [dim_size] * d
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10, size=shape).astype(np.int64),
+ axis=(d - 1))
+ arr = np.random.randint(
+ low=0, high=10 * dim_size, size=shape).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ cdf = cdf.reshape([-1, dim_size])
+ arr = arr.reshape([-1, dim_size])
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(dim_size**(d - 1)):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ result = result.reshape(shape)
+
+ self.assertAllEqual(result, tf_result)
+
+ def testLowerBoundIntUneven(self):
+ batch_size = 7
+ size_search_array = 1000
+ size_values = 47
+ cdf = np.cumsum(
+ np.random.randint(low=0, high=10,
+ size=[batch_size,
+ size_search_array]).astype(np.int64),
+ axis=1)
+ arr = np.random.randint(
+ low=0, high=10 * size_search_array, size=[batch_size,
+ size_values]).astype(np.int64)
+
+ tf_result = self.evaluate(array_ops.searchsorted(cdf, arr, side="left"))
+
+ result = np.zeros(arr.shape, dtype=np.int32)
+ for i in range(batch_size):
+ result[i, :] = np.searchsorted(cdf[i, :], arr[i, :], side="left")
+
+ self.assertAllEqual(result, tf_result)
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/kernel_tests/as_string_op_test.py b/tensorflow/python/kernel_tests/as_string_op_test.py
index 51aa17babe..dd4a90e5f6 100644
--- a/tensorflow/python/kernel_tests/as_string_op_test.py
+++ b/tensorflow/python/kernel_tests/as_string_op_test.py
@@ -32,7 +32,7 @@ class AsStringOpTest(test.TestCase):
0, 1, -1, 0.5, 0.25, 0.125, float("INF"), float("NAN"), float("-INF")
]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.float32, dtypes.float64):
input_ = array_ops.placeholder(dtype)
@@ -84,7 +84,7 @@ class AsStringOpTest(test.TestCase):
int_inputs_ = [0, -1, 1, -128, 127, -101, 101, -0]
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.int32, dtypes.int64, dtypes.int8):
input_ = array_ops.placeholder(dtype)
@@ -117,7 +117,7 @@ class AsStringOpTest(test.TestCase):
# testing int8
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
input_ = array_ops.placeholder(dtypes.int32)
int_inputs_ = [np.iinfo(np.int32).min, np.iinfo(np.int32).max]
output = string_ops.as_string(input_)
@@ -133,7 +133,7 @@ class AsStringOpTest(test.TestCase):
def testHalfInt(self):
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
input_ = array_ops.placeholder(dtypes.int16)
int_inputs_ = [np.iinfo(np.int16).min, np.iinfo(np.int16).max]
output = string_ops.as_string(input_)
@@ -144,7 +144,7 @@ class AsStringOpTest(test.TestCase):
bool_inputs_ = [False, True]
s = lambda strs: [x.decode("ascii") for x in strs]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.bool,):
input_ = array_ops.placeholder(dtype)
@@ -159,7 +159,7 @@ class AsStringOpTest(test.TestCase):
]
complex_inputs_ = [(x + (x + 1) * 1j) for x in float_inputs_]
- with self.test_session():
+ with self.cached_session():
for dtype in (dtypes.complex64, dtypes.complex128):
input_ = array_ops.placeholder(dtype)
diff --git a/tensorflow/python/kernel_tests/atrous_convolution_test.py b/tensorflow/python/kernel_tests/atrous_convolution_test.py
index b98e5fd386..6b16fca29d 100644
--- a/tensorflow/python/kernel_tests/atrous_convolution_test.py
+++ b/tensorflow/python/kernel_tests/atrous_convolution_test.py
@@ -263,7 +263,7 @@ class AtrousConvolutionTest(test.TestCase):
self.assertLess(err, err_tolerance)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for rate_width in range(1, 3):
for rate_height in range(1, 3):
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
index fb74698660..1e09ba5b65 100644
--- a/tensorflow/python/kernel_tests/attention_ops_test.py
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -84,7 +84,7 @@ class ExtractGlimpseTest(test.TestCase):
image_ops.extract_glimpse(t_cols_4d, t1, t2), [0, 2, 1, 3]))
# Evaluate the TensorFlow Graph.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols])
# Check dimensions of returned glimpse.
@@ -118,7 +118,7 @@ class ExtractGlimpseTest(test.TestCase):
def testEmptyTensor(self):
empty_image = np.zeros((0, 4, 3, 0))
offsets = np.zeros((0, 2))
- with self.test_session():
+ with self.cached_session():
result = image_ops.extract_glimpse(empty_image, [1, 1], offsets)
self.assertAllEqual(
np.zeros(
diff --git a/tensorflow/python/kernel_tests/barrier_ops_test.py b/tensorflow/python/kernel_tests/barrier_ops_test.py
index 7f49c63957..4d36b3a465 100644
--- a/tensorflow/python/kernel_tests/barrier_ops_test.py
+++ b/tensorflow/python/kernel_tests/barrier_ops_test.py
@@ -67,7 +67,7 @@ class BarrierTest(test.TestCase):
""", b.barrier_ref.op.node_def)
def testInsertMany(self):
- with self.test_session():
+ with self.cached_session():
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -83,7 +83,7 @@ class BarrierTest(test.TestCase):
self.assertEquals(size_t.eval(), [3])
def testInsertManyEmptyTensor(self):
- with self.test_session():
+ with self.cached_session():
error_message = ("Empty tensors are not supported, but received shape "
r"\'\(0,\)\' at index 1")
with self.assertRaisesRegexp(ValueError, error_message):
@@ -91,7 +91,7 @@ class BarrierTest(test.TestCase):
(dtypes.float32, dtypes.float32), shapes=((1,), (0,)), name="B")
def testInsertManyEmptyTensorUnknown(self):
- with self.test_session():
+ with self.cached_session():
b = data_flow_ops.Barrier((dtypes.float32, dtypes.float32), name="B")
size_t = b.ready_size()
self.assertEqual([], size_t.get_shape())
@@ -103,7 +103,7 @@ class BarrierTest(test.TestCase):
insert_0_op.run()
def testTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -128,7 +128,7 @@ class BarrierTest(test.TestCase):
self.assertEqual(values_1_val[idx], v1)
def testTakeManySmallBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -192,7 +192,7 @@ class BarrierTest(test.TestCase):
insert_1_3_op.run()
def testUseBarrierWithShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((2, 2), (8,)), name="B")
size_t = b.ready_size()
@@ -221,7 +221,7 @@ class BarrierTest(test.TestCase):
self.assertAllEqual(values_1_val[idx], v1)
def testParallelInsertMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
@@ -241,7 +241,7 @@ class BarrierTest(test.TestCase):
self.assertEqual(values_val[idx], v)
def testParallelTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
size_t = b.ready_size()
keys = [str(x).encode("ascii") for x in range(10)]
@@ -275,7 +275,7 @@ class BarrierTest(test.TestCase):
zip(keys, values), [(k[0], v[0]) for k, v in zip(key_vals, value_vals)])
def testBlockingTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(dtypes.float32, shapes=())
keys = [str(x).encode("ascii") for x in range(10)]
values = [float(x) for x in range(10)]
@@ -297,7 +297,7 @@ class BarrierTest(test.TestCase):
t.join()
def testParallelInsertManyTakeMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
@@ -376,7 +376,7 @@ class BarrierTest(test.TestCase):
self.assertAllEqual(taken_i["values_1"], expected_values_1)
def testClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -434,7 +434,7 @@ class BarrierTest(test.TestCase):
sess.run(take_t[0])
def testCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
size_t = b.ready_size()
@@ -487,7 +487,7 @@ class BarrierTest(test.TestCase):
sess.run(take_t[0])
def _testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
take_t = b.take_many(1, allow_small_batch=True)
@@ -500,7 +500,7 @@ class BarrierTest(test.TestCase):
self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=True)
def _testParallelInsertManyTakeManyCloseHalfwayThrough(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 50
@@ -576,7 +576,7 @@ class BarrierTest(test.TestCase):
self._testParallelInsertManyTakeManyCloseHalfwayThrough(cancel=True)
def _testParallelPartialInsertManyTakeManyCloseHalfwayThrough(self, cancel):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.int64), shapes=((), (2,)))
num_iterations = 100
@@ -676,7 +676,7 @@ class BarrierTest(test.TestCase):
self._testParallelPartialInsertManyTakeManyCloseHalfwayThrough(cancel=True)
def testIncompatibleSharedBarrierErrors(self):
- with self.test_session():
+ with self.cached_session():
# Do component types and shapes.
b_a_1 = data_flow_ops.Barrier(
(dtypes.float32,), shapes=(()), shared_name="b_a")
diff --git a/tensorflow/python/kernel_tests/base64_ops_test.py b/tensorflow/python/kernel_tests/base64_ops_test.py
index be96f45497..1b399942ef 100644
--- a/tensorflow/python/kernel_tests/base64_ops_test.py
+++ b/tensorflow/python/kernel_tests/base64_ops_test.py
@@ -48,7 +48,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
return base64_msg
def _RunTest(self, msg, pad):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if pad:
encoded, decoded = sess.run([self._encoded_t, self._decoded_t],
feed_dict={self._msg: msg})
@@ -92,7 +92,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
encoded = string_ops.encode_base64(msg, pad=pad)
decoded = string_ops.decode_base64(encoded)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
encoded_value, decoded_value = sess.run([encoded, decoded])
self.assertEqual(encoded_value.shape, msg.shape)
@@ -102,7 +102,7 @@ class Base64OpsTest(test_util.TensorFlowTestCase):
def try_decode(enc):
self._decoded_f.eval(feed_dict={self._encoded_f: enc})
- with self.test_session():
+ with self.cached_session():
# Invalid length.
msg = np.random.bytes(99)
enc = base64.urlsafe_b64encode(msg)
diff --git a/tensorflow/python/kernel_tests/basic_gpu_test.py b/tensorflow/python/kernel_tests/basic_gpu_test.py
index 987a6ffcd4..67e8618198 100644
--- a/tensorflow/python/kernel_tests/basic_gpu_test.py
+++ b/tensorflow/python/kernel_tests/basic_gpu_test.py
@@ -174,7 +174,7 @@ class BroadcastSimpleTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
@@ -195,7 +195,7 @@ class BroadcastSimpleTest(test.TestCase):
numeric_gradient_type=None):
z = np_func(x, y)
zs = list(z.shape)
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
if x.dtype in (np.float32, np.float64):
@@ -260,7 +260,7 @@ class GpuMultiSessionMemoryTest(test_util.TensorFlowTestCase):
threads = []
results = []
for _ in xrange(n_threads):
- session = self.test_session(graph=ops.Graph(), use_gpu=True)
+ session = self.session(graph=ops.Graph(), use_gpu=True)
results.append(set())
args = (session, results[-1])
threads.append(threading.Thread(target=self._run_session, args=args))
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 8e7ae89f9d..7dd347989a 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -86,7 +86,7 @@ class GatherTest(test.TestCase):
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
- with self.test_session():
+ with self.cached_session():
indices_tf = constant_op.constant([1])
self.assertAllEqual([[b"qwer", b"uiop"]],
array_ops.batch_gather(params, indices_tf).eval())
diff --git a/tensorflow/python/kernel_tests/batchtospace_op_test.py b/tensorflow/python/kernel_tests/batchtospace_op_test.py
index 6143cd3baa..03f3f64353 100644
--- a/tensorflow/python/kernel_tests/batchtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/batchtospace_op_test.py
@@ -60,7 +60,7 @@ class BatchToSpaceDepthToSpace(test.TestCase, PythonOpImpl):
array_ops.depth_to_space(
array_ops.transpose(x, [3, 1, 2, 0]), block_size=block_size),
[3, 1, 2, 0])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(y1.eval(), y2.eval())
@@ -235,7 +235,7 @@ class BatchToSpaceGradientTest(test.TestCase, PythonOpImpl):
# Check the gradients.
def _checkGrad(self, x, crops, block_size):
assert 4 == x.ndim
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = self.batch_to_space(tf_x, crops, block_size)
epsilon = 1e-5
@@ -293,7 +293,7 @@ class BatchToSpaceNDGradientTest(test.TestCase):
block_shape = np.array(block_shape)
crops = constant_op.constant(
np.array(crops).reshape((len(block_shape), 2)), crops_dtype)
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = array_ops.batch_to_space_nd(tf_x, block_shape, crops)
epsilon = 1e-5
diff --git a/tensorflow/python/kernel_tests/bcast_ops_test.py b/tensorflow/python/kernel_tests/bcast_ops_test.py
index 3305e55c05..3ec820aead 100644
--- a/tensorflow/python/kernel_tests/bcast_ops_test.py
+++ b/tensorflow/python/kernel_tests/bcast_ops_test.py
@@ -28,11 +28,11 @@ from tensorflow.python.platform import test
class BcastOpsTest(test.TestCase):
def _GetBroadcastShape(self, xs, ys):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(broadcast_args(xs, ys))
def _GetGradientArgs(self, xs, ys):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(broadcast_gradient_args(xs, ys))
def testBasic(self):
diff --git a/tensorflow/python/kernel_tests/betainc_op_test.py b/tensorflow/python/kernel_tests/betainc_op_test.py
index 16fdedac41..92d21462d5 100644
--- a/tensorflow/python/kernel_tests/betainc_op_test.py
+++ b/tensorflow/python/kernel_tests/betainc_op_test.py
@@ -47,7 +47,7 @@ class BetaincTest(test.TestCase):
tf_b_s = constant_op.constant(b_s, dtype=dtype)
tf_x_s = constant_op.constant(x_s, dtype=dtype)
tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
- with self.test_session():
+ with self.cached_session():
tf_out = tf_out_t.eval()
scipy_out = special.betainc(a_s, b_s, x_s).astype(np_dt)
@@ -60,13 +60,13 @@ class BetaincTest(test.TestCase):
# Test out-of-range values (most should return nan output)
combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
- with self.test_session():
+ with self.cached_session():
tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
scipy_comb = special.betainc(a_comb, b_comb, x_comb).astype(np_dt)
self.assertAllCloseAccordingToType(scipy_comb, tf_comb)
# Test broadcasting between scalars and other shapes
- with self.test_session():
+ with self.cached_session():
self.assertAllCloseAccordingToType(
special.betainc(0.1, b_s, x_s).astype(np_dt),
math_ops.betainc(0.1, b_s, x_s).eval(),
@@ -96,7 +96,7 @@ class BetaincTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "must be equal"):
math_ops.betainc(0.5, [0.5], [[0.5]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Shapes of .* are inconsistent"):
a_p = array_ops.placeholder(dtype)
b_p = array_ops.placeholder(dtype)
@@ -140,7 +140,7 @@ class BetaincTest(test.TestCase):
self._testBetaInc(a_s, b_s, x_s, dtypes.float32)
def testBetaIncFpropAndBpropAreNeverNAN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
space = np.logspace(-8, 5).tolist()
space_x = np.linspace(1e-16, 1 - 1e-16).tolist()
ga_s, gb_s, gx_s = zip(*list(itertools.product(space, space, space_x)))
@@ -161,7 +161,7 @@ class BetaincTest(test.TestCase):
def testBetaIncGrads(self):
err_tolerance = 1e-3
- with self.test_session():
+ with self.cached_session():
# Test gradient
ga_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty)
gb_s = np.abs(np.random.randn(2, 2) * 30) # in (0, infty)
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 2767df127e..8a58b3f97e 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -93,7 +93,7 @@ class BincountTest(test_util.TensorFlowTestCase):
def test_negative(self):
# unsorted_segment_sum will only report InvalidArgumentError on CPU
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/BUILD b/tensorflow/python/kernel_tests/boosted_trees/BUILD
index 4f92ab0795..20446781f0 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/BUILD
+++ b/tensorflow/python/kernel_tests/boosted_trees/BUILD
@@ -74,3 +74,16 @@ tf_py_test(
"//tensorflow/python:resources",
],
)
+
+tf_py_test(
+ name = "quantile_ops_test",
+ size = "small",
+ srcs = ["quantile_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/core/kernels/boosted_trees:boosted_trees_proto_py",
+ "//tensorflow/python:boosted_trees_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:resources",
+ ],
+)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
index 4e31b1ea2a..7cdc67f83f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/prediction_ops_test.py
@@ -30,7 +30,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionOnEmptyEnsemble(self):
"""Tests that prediction on a dummy ensemble does not fail."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create a dummy ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto='')
@@ -63,7 +63,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testNoCachedPredictionButTreeExists(self):
"""Tests that predictions are updated once trees are added."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -129,7 +129,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionIsCurrent(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -201,7 +201,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromTheSameTree(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -315,7 +315,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromPreviousTree(self):
"""Tests the predictions work when we have cache from previous trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -445,9 +445,81 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
# change= 0.1(1.14+7.0-7.0)
self.assertAllClose([[1], [0.114]], logits_updates)
+ def testCategoricalSplits(self):
+ """Tests the training prediction work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_metadata {
+ is_finalized: true
+ }
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ # No previous cached values.
+ cached_tree_ids = [0, 0, 0]
+ cached_node_ids = [0, 0, 0]
+
+ # Grow tree ensemble.
+ predict_op = boosted_trees_ops.training_predict(
+ tree_ensemble_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits_updates, new_tree_ids, new_node_ids = session.run(predict_op)
+
+ self.assertAllClose([0, 0, 0], new_tree_ids)
+ self.assertAllClose([3, 4, 2], new_node_ids)
+ self.assertAllClose([[5.], [6.], [7.]], logits_updates)
+
def testCachedPredictionFromTheSameTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -577,7 +649,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionFromThePreviousTreeWithPostPrunedNodes(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -722,7 +794,7 @@ class TrainingPredictionOpsTest(test_util.TensorFlowTestCase):
def testCachedPredictionTheWholeTreeWasPruned(self):
"""Tests that prediction based on previous node in the tree works."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -794,7 +866,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def testPredictionOnEmptyEnsemble(self):
"""Tests that prediction on a empty ensemble does not fail."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create an empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble(
'ensemble', serialized_proto='')
@@ -816,7 +888,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
def testPredictionMultipleTree(self):
"""Tests the predictions work when we have multiple trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -924,13 +996,232 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
logits = session.run(predict_op)
self.assertAllClose(expected_logits, logits)
+ def testCategoricalSplits(self):
+ """Tests the predictions work for categorical splits."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ categorical_split {
+ feature_id: 1
+ value: 2
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ categorical_split {
+ feature_id: 0
+ value: 13
+ left_id: 3
+ right_id: 4
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ """, tree_ensemble_config)
+
+ # Create existing ensemble with one root split
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [13, 1, 3]
+ feature_1_values = [2, 2, 1]
+
+ expected_logits = [[5.], [6.], [7.]]
+
+ # Prediction should work fine.
+ predict_op = boosted_trees_ops.predict(
+ tree_ensemble_handle,
+ bucketized_features=[feature_0_values, feature_1_values],
+ logits_dimension=1)
+
+ logits = session.run(predict_op)
+ self.assertAllClose(expected_logits, logits)
+
class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
"""Tests feature contribs ops for model understanding."""
+ def testContribsForOnlyABiasNode(self):
+ """Tests case when, after training, only left with a bias node.
+
+ For example, this could happen if the final ensemble contains one tree that
+ got pruned up to the root.
+ """
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # All features are unused.
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29]
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ bias = 1.72 * 0.1 # Root node of tree_0.
+ expected_feature_ids = ((), ())
+ expected_logits_paths = ((bias,), (bias,))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
+ def testContribsMultipleTreeWhenFirstTreeIsABiasNode(self):
+ """Tests case when, after training, first tree contains only a bias node."""
+ with self.cached_session() as session:
+ tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ leaf {
+ scalar: 1.72
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ threshold: 26
+ left_id: 1
+ right_id: 2
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ threshold: 50
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ original_leaf: {scalar: 5.5}
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 7.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 6.0
+ }
+ }
+ }
+ tree_weights: 1.
+ tree_weights: 0.1
+ tree_metadata: {
+ num_layers_grown: 0
+ }
+ tree_metadata: {
+ num_layers_grown: 1
+ }
+ """, tree_ensemble_config)
+
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ 'ensemble', serialized_proto=tree_ensemble_config.SerializeToString())
+ tree_ensemble_handle = tree_ensemble.resource_handle
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ feature_0_values = [36, 32]
+ feature_1_values = [13, -29] # Unused feature.
+ feature_2_values = [11, 27]
+
+ # Expected logits are computed by traversing the logit path and
+ # subtracting child logits from parent logits.
+ expected_feature_ids = ((2, 0), (2,))
+ # bias = 1.72 * 1. # Root node of tree_0.
+ # example_0 : (bias, 0.1 * 5.5 + bias, 0.1 * 5. + bias)
+ # example_1 : (bias, 0.1 * 7. + bias )
+ expected_logits_paths = ((1.72, 2.27, 2.22), (1.72, 2.42))
+
+ bucketized_features = [
+ feature_0_values, feature_1_values, feature_2_values
+ ]
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble_handle,
+ bucketized_features=bucketized_features,
+ logits_dimension=1)
+
+ serialized_examples_debug_outputs = session.run(debug_op)
+ feature_ids = []
+ logits_paths = []
+ for example in serialized_examples_debug_outputs:
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example)
+ feature_ids.append(example_debug_outputs.feature_ids)
+ logits_paths.append(example_debug_outputs.logits_path)
+
+ self.assertAllClose(feature_ids, expected_feature_ids)
+ self.assertAllClose(logits_paths, expected_logits_paths)
+
def testContribsMultipleTree(self):
"""Tests that the contribs work when we have multiple trees."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -1018,11 +1309,14 @@ class FeatureContribsOpsTest(test_util.TensorFlowTestCase):
tree_weights: 0.2
tree_weights: 1.0
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
tree_metadata: {
- num_layers_grown: 2}
+ num_layers_grown: 2
+ }
tree_metadata: {
- num_layers_grown: 1}
+ num_layers_grown: 1
+ }
""", tree_ensemble_config)
tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
new file mode 100644
index 0000000000..e0d46bae83
--- /dev/null
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test for checking quantile related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import boosted_trees_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as resource_handle_op
+from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as resource_initialized
+from tensorflow.python.platform import googletest
+
+
+class QuantileOpsTest(test_util.TensorFlowTestCase):
+
+ def create_resource(self, name, eps, max_elements, num_streams=1):
+ quantile_accumulator_handle = resource_handle_op(
+ container="", shared_name=name, name=name)
+ create_op = boosted_trees_ops.create_quantile_stream_resource(
+ quantile_accumulator_handle,
+ epsilon=eps,
+ max_elements=max_elements,
+ num_streams=num_streams)
+ is_initialized_op = resource_initialized(quantile_accumulator_handle)
+ resources.register_resource(quantile_accumulator_handle, create_op,
+ is_initialized_op)
+ return quantile_accumulator_handle
+
+ def setUp(self):
+ """Sets up the quantile ops test as follows.
+
+ Create a batch of 6 examples having 2 features
+ The data looks like this
+ | Instance | instance weights | Feature 0 | Feature 1
+ | 0 | 10 | 1.2 | 2.3
+ | 1 | 1 | 12.1 | 1.2
+ | 2 | 1 | 0.3 | 1.1
+ | 3 | 1 | 0.5 | 2.6
+ | 4 | 1 | 0.6 | 3.2
+ | 5 | 1 | 2.2 | 0.8
+ """
+
+ self._feature_0 = constant_op.constant(
+ [[1.2], [12.1], [0.3], [0.5], [0.6], [2.2]], dtype=dtypes.float32)
+ self._feature_1 = constant_op.constant(
+ [[2.3], [1.2], [1.1], [2.6], [3.2], [0.8]], dtype=dtypes.float32)
+ self._feature_0_boundaries = constant_op.constant(
+ [0.3, 0.6, 1.2, 12.1], dtype=dtypes.float32)
+ self._feature_1_boundaries = constant_op.constant(
+ [0.8, 1.2, 2.3, 3.2], dtype=dtypes.float32)
+ self._feature_0_quantiles = constant_op.constant(
+ [[2], [3], [0], [1], [1], [3]], dtype=dtypes.int32)
+ self._feature_1_quantiles = constant_op.constant(
+ [[2], [1], [1], [3], [3], [0]], dtype=dtypes.int32)
+
+ self._example_weights = constant_op.constant(
+ [10, 1, 1, 1, 1, 1], dtype=dtypes.float32)
+
+ self.eps = 0.01
+ self.max_elements = 1 << 16
+ self.num_quantiles = constant_op.constant(3, dtype=dtypes.int64)
+
+ def testBasicQuantileBucketsSingleResource(self):
+ with self.cached_session() as sess:
+ quantile_accumulator_handle = self.create_resource("floats", self.eps,
+ self.max_elements, 2)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle, summaries)
+ flush_op = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle, self.num_quantiles)
+ buckets = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle, num_features=2)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], buckets)
+ sess.run(summary_op)
+ sess.run(flush_op)
+ self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+ def testBasicQuantileBucketsMultipleResources(self):
+ with self.cached_session() as sess:
+ quantile_accumulator_handle_0 = self.create_resource("float_0", self.eps,
+ self.max_elements)
+ quantile_accumulator_handle_1 = self.create_resource("float_1", self.eps,
+ self.max_elements)
+ resources.initialize_resources(resources.shared_resources()).run()
+ summaries = boosted_trees_ops.make_quantile_summaries(
+ [self._feature_0, self._feature_1], self._example_weights,
+ epsilon=self.eps)
+ summary_op_0 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_0,
+ [summaries[0]])
+ summary_op_1 = boosted_trees_ops.quantile_add_summaries(
+ quantile_accumulator_handle_1,
+ [summaries[1]])
+ flush_op_0 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_0, self.num_quantiles)
+ flush_op_1 = boosted_trees_ops.quantile_flush(
+ quantile_accumulator_handle_1, self.num_quantiles)
+ bucket_0 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_0, num_features=1)
+ bucket_1 = boosted_trees_ops.get_bucket_boundaries(
+ quantile_accumulator_handle_1, num_features=1)
+ quantiles = boosted_trees_ops.boosted_trees_bucketize(
+ [self._feature_0, self._feature_1], bucket_0 + bucket_1)
+ sess.run([summary_op_0, summary_op_1])
+ sess.run([flush_op_0, flush_op_1])
+ self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
+ self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
+
+ self.assertAllClose(self._feature_0_quantiles, quantiles[0].eval())
+ self.assertAllClose(self._feature_1_quantiles, quantiles[1].eval())
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
index d5f0c22d6e..65bb9ab55f 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/resource_ops_test.py
@@ -31,7 +31,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
"""Tests resource_ops."""
def testCreate(self):
- with self.test_session():
+ with self.cached_session():
ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
resources.initialize_resources(resources.shared_resources()).run()
stamp_token = ensemble.get_stamp_token()
@@ -44,7 +44,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([0, 1], nodes_range.eval())
def testCreateWithProto(self):
- with self.test_session():
+ with self.cached_session():
ensemble_proto = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -161,7 +161,7 @@ class ResourceOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([16, 19], nodes_range.eval())
def testSerializeDeserialize(self):
- with self.test_session():
+ with self.cached_session():
# Initialize.
ensemble = boosted_trees_ops.TreeEnsemble('ensemble', stamp_token=5)
resources.initialize_resources(resources.shared_resources()).run()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 568e695fd5..09e9cfa3af 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -30,7 +30,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithoutRegularization(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -78,7 +78,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithL2(self):
"""Testing Gain calculation with L2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -126,7 +126,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithL1(self):
"""Testing Gain calculation with L1."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -177,7 +177,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithTreeComplexity(self):
"""Testing Gain calculation with L2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -229,7 +229,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithMinNodeWeight(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -276,7 +276,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
"""Testing Gain calculation without any regularization."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
max_splits = 7
node_id_range = [1, 3] # node 1 through 2 will be processed.
stats_summary_list = [
@@ -329,7 +329,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummarySimple(self):
"""Simple test for MakeStatsSummary."""
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[[[1., 5.], [2., 6.]], [[3., 7.], [4., 8.]]]],
boosted_trees_ops.make_stats_summary(
node_ids=[0, 0, 1, 1],
@@ -341,7 +341,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummaryAccumulate(self):
"""Tests that Summary actually accumulates."""
- with self.test_session():
+ with self.cached_session():
max_splits = 3
num_buckets = 4
node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -363,7 +363,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
def testMakeStatsSummaryMultipleFeatures(self):
"""Tests that MakeStatsSummary works for multiple features."""
- with self.test_session():
+ with self.cached_session():
max_splits = 3
num_buckets = 4
node_ids = [1, 1, 2, 2, 1, 1, 2, 0]
@@ -392,7 +392,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
result.eval())
def _verify_precision(self, length):
- with self.test_session():
+ with self.cached_session():
max_splits = 1
num_buckets = 1
node_ids = array_ops.fill([length], 0)
diff --git a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
index d55240297a..ea022820e4 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/training_ops_test.py
@@ -32,7 +32,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowWithEmptyEnsemble(self):
"""Test growing an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -141,7 +141,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testBiasCenteringOnEmptyEnsemble(self):
"""Test growing with bias centering on an empty ensemble."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble = boosted_trees_ops.TreeEnsemble('ensemble')
tree_ensemble_handle = tree_ensemble.resource_handle
@@ -184,7 +184,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -368,7 +368,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowExistingEnsembleTreeFinalized(self):
"""Test growing an existing ensemble with the last tree finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -517,7 +517,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPrePruning(self):
"""Test growing an existing ensemble with pre-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge("""
trees {
@@ -673,7 +673,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDueToEmptySplits(self):
"""Test that the metadata is updated even though we can't split."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -784,7 +784,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testMetadataWhenCantSplitDuePrePruning(self):
"""Test metadata is updated correctly when no split due to prepruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
text_format.Merge(
"""
@@ -919,7 +919,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfSomeNodes(self):
"""Test growing an ensemble with post-pruning."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
@@ -1253,7 +1253,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningOfAllNodes(self):
"""Test growing an ensemble with post-pruning, with all nodes are pruned."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
@@ -1436,7 +1436,7 @@ class UpdateTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testPostPruningChangesNothing(self):
"""Test growing an ensemble with post-pruning with all gains >0."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = boosted_trees_pb2.TreeEnsemble()
tree_ensemble = boosted_trees_ops.TreeEnsemble(
diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
index bd2339f31d..09c325f2bc 100644
--- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
+++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py
@@ -90,7 +90,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
x = constant_op.constant(1, dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [2, 4, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -100,7 +100,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [2, 5, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -110,7 +110,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [5, 2, 3])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
@@ -119,7 +119,7 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32)
v = array_ops.broadcast_to(x, [5, 4, 6])
out = 2 * v
- with self.test_session():
+ with self.cached_session():
err = gradient_checker.compute_gradient_error(x, x.get_shape(),
out, out.get_shape())
self.assertLess(err, 1e-4)
diff --git a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
index 28b3dc45e9..b19077db56 100644
--- a/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
+++ b/tensorflow/python/kernel_tests/candidate_sampler_ops_test.py
@@ -38,7 +38,7 @@ class RangeSamplerOpsTest(test.TestCase):
TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]
def testTrueCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = constant_op.constant([0, 0, 1, 1, 2, 2])
true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3])
true_candidates_matrix = array_ops.reshape(
@@ -50,7 +50,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)
def testSampledCandidates(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -62,7 +62,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])
def testTrueLogExpectedCount(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
_, true_expected_count, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -77,7 +77,7 @@ class RangeSamplerOpsTest(test.TestCase):
[self.BATCH_SIZE, self.NUM_TRUE])
def testSampledLogExpectedCount(self):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
_, _, sampled_expected_count = candidate_sampling_ops.all_candidate_sampler( # pylint: disable=line-too-long
@@ -90,7 +90,7 @@ class RangeSamplerOpsTest(test.TestCase):
self.assertEqual(sampled_log_expected_count.get_shape(), [self.NUM_SAMPLED])
def testAccidentalHits(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled_candidates, _, _ = candidate_sampling_ops.all_candidate_sampler(
@@ -109,7 +109,7 @@ class RangeSamplerOpsTest(test.TestCase):
def testSeed(self):
def draw(seed):
- with self.test_session():
+ with self.cached_session():
true_classes = constant_op.constant(
[[1, 2], [0, 4], [3, 3]], dtype=dtypes.int64)
sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler(
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 214d5cb3c0..c90520e46d 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -174,7 +174,7 @@ class CastOpTest(test.TestCase):
self.assertAllEqual(np.isnan(self._cast(np.nan, np.float64, True)), True)
def _OpError(self, x, dtype, err):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(err):
math_ops.cast(x, dtype).eval()
@@ -182,7 +182,7 @@ class CastOpTest(test.TestCase):
self._OpError(np.arange(0, 10), dtypes.string, "Cast.*int64.*string.*")
def testCastToTypeOfVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(5, dtype=dtypes.float32)
y = variables.Variable(True, dtype=dtypes.bool)
cast = math_ops.cast(y, x.dtype)
@@ -193,7 +193,7 @@ class CastOpTest(test.TestCase):
t = [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
for src_t in t:
for dst_t in t:
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1.0, src_t)
z = array_ops.identity(x)
y = math_ops.cast(z, dst_t)
@@ -209,7 +209,7 @@ class SparseTensorCastTest(test.TestCase):
shape = constant_op.constant([3], dtypes.int64)
st = sparse_tensor.SparseTensor(indices, values, shape)
st_cast = math_ops.cast(st, dtypes.float32)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(st_cast.indices.eval(), [[0], [1], [2]])
self.assertAllEqual(st_cast.values.eval(),
np.array([1, 2, 3], np.float32))
@@ -221,7 +221,7 @@ class SaturateCastTest(test.TestCase):
def testSaturate(self):
in_types = dtypes.float32,
out_types = dtypes.int8, dtypes.uint8, dtypes.int16, dtypes.float32
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for in_type in in_types:
for out_type in out_types:
lo, hi = in_type.min, in_type.max
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index 27a674e223..bd4011d58e 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -785,7 +785,7 @@ class EnsureShapeTest(test.TestCase):
derived = math_ops.divide(placeholder, 3, name="MyDivide")
derived = check_ops.ensure_shape(derived, (3, 3, 3))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"Shape of tensor MyDivide \[2,1\] is not compatible with "
@@ -797,7 +797,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (None, None, 3))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
@@ -809,7 +809,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (2, 1))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(derived, feed_dict={placeholder: feed_val})
def testEnsuresDynamicShape_WithUnknownDims(self):
@@ -817,7 +817,7 @@ class EnsureShapeTest(test.TestCase):
derived = placeholder / 3
derived = check_ops.ensure_shape(derived, (None, None))
feed_val = [[1], [2]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(derived, feed_dict={placeholder: feed_val})
def testGradient(self):
@@ -826,7 +826,7 @@ class EnsureShapeTest(test.TestCase):
gradient = gradients.gradients(derived, placeholder)
feed_val = [[4.0], [-1.0]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradient_values, = sess.run(gradient, feed_dict={placeholder: feed_val})
expected = [[1.0], [1.0]]
diff --git a/tensorflow/python/kernel_tests/checkpoint_ops_test.py b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
index 7f147ba53a..51611b75af 100644
--- a/tensorflow/python/kernel_tests/checkpoint_ops_test.py
+++ b/tensorflow/python/kernel_tests/checkpoint_ops_test.py
@@ -57,7 +57,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = range(0, 3)
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -70,7 +70,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=0)
expected_remapping = [2, 0, 1]
expected_num_present = 3
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -83,7 +83,7 @@ class GenerateVocabRemappingTest(test.TestCase):
new_vocab_offset=1)
expected_remapping = [0]
expected_num_present = 1
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -98,7 +98,7 @@ class GenerateVocabRemappingTest(test.TestCase):
old_vocab_size=2)
expected_remapping = [-1, 0, 1]
expected_num_present = 2
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_remapping, remapping.eval())
self.assertAllEqual(expected_num_present, num_present.eval())
@@ -122,7 +122,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
self.old_tensor_name = 'some_scope/matrix'
save = saver.Saver([matrix])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
self.bundle_file = os.path.join(test.get_temp_dir(), 'bundle_checkpoint')
save.save(sess, self.bundle_file)
@@ -140,7 +140,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=2,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping],
remapped_matrix.eval())
@@ -155,7 +155,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -170,7 +170,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(row_remapping),
num_cols=len(col_remapping))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(self.matrix_value[row_remapping][:, col_remapping],
remapped_matrix.eval())
@@ -189,7 +189,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
expected_remapped_matrix = np.reshape(
[33, init_val, init_val, init_val, 1, init_val], [3, 2])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_remapped_matrix, remapped_matrix.eval())
def test_load_and_remap_all_missing_rows(self):
@@ -204,7 +204,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=self.old_num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, self.old_num_cols)),
remapped_matrix.eval())
@@ -222,7 +222,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=initializing_values,
num_rows=num_rows,
num_cols=num_cols)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.reshape(initializing_values, (num_rows, num_cols)),
remapped_matrix.eval())
@@ -243,7 +243,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=len(invalid_remapping),
num_cols=self.old_num_cols)
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
# Invalid column remapping.
@@ -255,7 +255,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=self.old_num_rows,
num_cols=len(invalid_remapping))
- with self.test_session(), self.assertRaises(errors.UnimplementedError):
+ with self.cached_session(), self.assertRaises(errors.UnimplementedError):
remapped_matrix.eval()
def test_load_and_remap_incorrect_initializing_values(self):
@@ -272,7 +272,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[],
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
remapped_matrix = gen_checkpoint_ops.load_and_remap_matrix(
@@ -284,7 +284,7 @@ class LoadAndRemapMatrixTest(test.TestCase):
initializing_values=[0] * 5,
num_rows=3,
num_cols=2)
- with self.test_session(), self.assertRaises(errors.InvalidArgumentError):
+ with self.cached_session(), self.assertRaises(errors.InvalidArgumentError):
remapped_matrix.eval()
@@ -306,7 +306,7 @@ class LoadAndRemapMatrixWithMaxRowsTest(test.TestCase):
initializer=constant_op.constant(np_value, dtype=dtypes.float32),
partitioner=partitioner)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ckpt_path = os.path.join(test.get_temp_dir(), 'temp_ckpt')
save = saver.Saver([matrix])
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index de52a70cc0..bb7b645da2 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -39,7 +39,7 @@ class ClipTest(test.TestCase):
min_val = constant_op.constant([0.5, 0.5, 0.5, 0.5], dtype=dtypes.float32)
max_val = constant_op.constant([3.5, 3.5, 3.5, 3.5], dtype=dtypes.float32)
outputs_2 = clip_ops.clip_by_value(inputs, min_val, max_val)
- with self.test_session():
+ with self.cached_session():
error_1 = gradient_checker.compute_gradient_error(inputs, [4], outputs_1,
[4])
self.assertLess(error_1, 1e-4)
@@ -139,7 +139,7 @@ class ClipTest(test.TestCase):
def testClipByValueNonFinite(self):
# TODO(b/78016351): Enable test on GPU once the bug is fixed.
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
np_ans = [float('NaN'), 4.0, -4.0]
clip_value = 4.0
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index c22934ce47..0e59ce6972 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -383,7 +383,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 0)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -397,7 +397,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 1)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -411,7 +411,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 0)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -426,7 +426,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 1)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -441,7 +441,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 2)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -452,7 +452,7 @@ class ConcatOpTest(test.TestCase):
def testIndexedSlicesConcatDim1Grad_UnknownInputDim(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [4, 11, 3]
- with self.test_session():
+ with self.cached_session():
x_1 = array_ops.placeholder(dtypes.float64)
x_2 = array_ops.placeholder(dtypes.float64)
x_3 = array_ops.placeholder(dtypes.float64)
@@ -473,13 +473,13 @@ class ConcatOpTest(test.TestCase):
def testConcatTuple(self):
c1 = np.random.rand(4, 4)
c2 = np.random.rand(4, 4)
- with self.test_session():
+ with self.cached_session():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
- with self.test_session():
+ with self.cached_session():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
with self.assertRaisesRegexp(
@@ -554,7 +554,7 @@ class ConcatOpTest(test.TestCase):
def _testGradientsForAxis(
self, inp_tensors, axis, output_shape, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.concat(inp_tensors, axis)
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
@@ -566,7 +566,7 @@ class ConcatOpTest(test.TestCase):
def _testIndexedSlicesGradientsForAxis(
self, inp_tensors, axis, output_shape, gather_indexes, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.gather(
array_ops.concat(inp_tensors, axis), gather_indexes)
grad_inp = np.random.rand(*output_shape).astype("f")
@@ -631,7 +631,7 @@ class ConcatOffsetTest(test.TestCase):
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
def testNotVector(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
@@ -641,7 +641,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testConcatDimOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(4, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -651,7 +651,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testDimMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
@@ -661,7 +661,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testSizeMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 10], dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 0dc3c53bc0..377c041675 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -41,7 +41,7 @@ class CondV2Test(test.TestCase):
def _testCond(self, true_fn, false_fn, train_vals, feed_dict=None):
if not feed_dict:
feed_dict = {}
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected")
@@ -107,7 +107,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNoInputs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
def true_fn():
@@ -131,7 +131,7 @@ class CondV2Test(test.TestCase):
def false_fn():
return x + 1
- return cond_v2.cond_v2(pred, true_fn, false_fn, name=name)[0].op
+ return cond_v2.cond_v2(pred, true_fn, false_fn, name=name).op
def testDefaultName(self):
with ops.Graph().as_default():
@@ -382,7 +382,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -445,7 +445,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -504,7 +504,7 @@ class CondV2Test(test.TestCase):
with ops.Graph().as_default():
grads, pred_outer, pred_inner = build_graph()
- with self.test_session(graph=ops.get_default_graph()) as sess:
+ with self.session(graph=ops.get_default_graph()) as sess:
self.assertSequenceEqual(
sess.run(grads, {
pred_outer: True,
@@ -527,7 +527,7 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testSecondDerivative(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
pred = array_ops.placeholder(dtypes.bool, name="pred")
x = constant_op.constant(3.0, name="x")
@@ -569,12 +569,11 @@ class CondV2Test(test.TestCase):
ops.add_to_collection("pred", pred)
cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
- for c in cond:
- ops.add_to_collection("cond", c)
+ ops.add_to_collection("cond", cond)
meta_graph = saver.export_meta_graph()
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
saver.import_meta_graph(meta_graph)
x = ops.get_collection("x")[0]
pred = ops.get_collection("pred")[0]
@@ -598,7 +597,7 @@ class CondV2Test(test.TestCase):
def testLowering(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
out_cond = self._createCond("cond")
run_options = config_pb2.RunOptions(output_partition_graphs=True)
@@ -624,7 +623,7 @@ class CondV2Test(test.TestCase):
"An `If` op was found, but it should be lowered.")
def testLoweringDisabledInXLA(self):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
# Build the cond_v2 in an XLA context
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
@@ -661,7 +660,7 @@ class CondV2CollectionTest(test.TestCase):
def testCollectionIntValueAccessInCond(self):
"""Read values from graph collections inside of cond_v2."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = 2
y = 5
ops.add_to_collection("x", x)
@@ -672,12 +671,12 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_const, y_const)
cnd = cond_v2.cond_v2(True, fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionTensorValueAccessInCond(self):
"""Read tensors from collections inside of cond_v2 & use them."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = constant_op.constant(2)
y = constant_op.constant(5)
ops.add_to_collection("x", x)
@@ -689,12 +688,12 @@ class CondV2CollectionTest(test.TestCase):
return math_ops.add(x_read, y_read)
cnd = cond_v2.cond_v2(math_ops.less(x, y), fn, fn)
- self.assertEquals(cnd[0].eval(), 7)
+ self.assertEquals(cnd.eval(), 7)
def testCollectionIntValueWriteInCond(self):
"""Make sure Int writes to collections work inside of cond_v2."""
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
x = constant_op.constant(2)
y = constant_op.constant(5)
def true_fn():
@@ -709,7 +708,7 @@ class CondV2CollectionTest(test.TestCase):
cnd = cond_v2.cond_v2(
True, true_fn,
false_fn)
- self.assertEquals(cnd[0].eval(), 14)
+ self.assertEquals(cnd.eval(), 14)
read_z_collection = ops.get_collection("z")
self.assertEquals(read_z_collection, [7])
@@ -725,7 +724,7 @@ class CondV2ContainerTest(test.TestCase):
"""
self.skipTest("b/113048653")
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
v0 = variables.Variable([0])
q0 = data_flow_ops.FIFOQueue(1, dtypes.float32)
@@ -782,10 +781,10 @@ class CondV2ContainerTest(test.TestCase):
with ops.container("l1"):
cnd_true = cond_v2.cond_v2(True, true_fn, false_fn)
- self.assertEquals(cnd_true[0].eval(), 2)
+ self.assertEquals(cnd_true.eval(), 2)
cnd_false = cond_v2.cond_v2(False, true_fn, false_fn)
- self.assertEquals(cnd_false[0].eval(), 6)
+ self.assertEquals(cnd_false.eval(), 6)
v4 = variables.Variable([3])
q4 = data_flow_ops.FIFOQueue(1, dtypes.float32)
@@ -801,9 +800,8 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
- self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
a = constant_op.constant([2.0], name="a")
b = constant_op.constant([2.0], name="b")
@@ -814,7 +812,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -823,12 +821,11 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(a.op):
with ops.colocate_with(b.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testColocateWithInAndOutOfCond(self):
- self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
a = constant_op.constant([2.0], name="a")
b = constant_op.constant([2.0], name="b")
@@ -840,7 +837,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.colocate_with(a.op):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant([2.0], name="d")
self.assertEqual([b"loc:@a"], d.op.colocation_groups())
@@ -861,7 +858,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.colocate_with(b.op):
c = math_ops.add(a, a, name="c")
return c
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
@@ -874,16 +871,16 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
- self.skipTest("b/112166045")
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.session(graph=g):
+
def fn():
c = constant_op.constant(3.0)
self.assertEqual("/device:CPU:0", c.op.device)
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn, fn)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn, fn).eval(), 3)
def fn2():
c = constant_op.constant(3.0)
@@ -891,19 +888,21 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
return c
with ops.device("/device:GPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
def testDeviceInAndOutOfCond(self):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g):
+ with self.test_session(
+ graph=g, config=config_pb2.ConfigProto(device_count={"CPU": 2})):
+
def fn2():
- with ops.device("/device:GPU:0"):
+ with ops.device("/device:CPU:1"):
c = constant_op.constant(3.0)
- self.assertEqual("/device:GPU:0", c.op.device)
+ self.assertEqual("/device:CPU:1", c.op.device)
return c
with ops.device("/device:CPU:0"):
- self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
+ self.assertEquals(cond_v2.cond_v2(True, fn2, fn2).eval(), 3)
d = constant_op.constant(4.0)
self.assertEqual("/device:CPU:0", d.op.device)
@@ -922,7 +921,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
with ops.device("/device:CPU:0"):
a = constant_op.constant([2.0], name="a")
- out_cond_2 = cond_v2.cond_v2(True, fn, fn)[0]
+ out_cond_2 = cond_v2.cond_v2(True, fn, fn)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 86802664d1..97ab23fe49 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -80,26 +80,26 @@ class ConditionalAccumulatorTest(test.TestCase):
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q")
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorSetGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
set_global_step_op = q.set_global_step(1)
set_global_step_op.run()
def testAccumulatorApplyGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
accum_op.run()
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
for i in range(len(dtypes)):
@@ -116,7 +116,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(sum(elems) / len(elems), result)
def testAccumulatorMultipleAccumulators(self):
- with self.test_session():
+ with self.cached_session():
q_f32_0 = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
q_f32_1 = data_flow_ops.ConditionalAccumulator(
@@ -135,7 +135,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(result, i + 10.0)
def testAccumulatorApplyAndTakeGradWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(3, 2))
elems = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
@@ -166,7 +166,7 @@ class ConditionalAccumulatorTest(test.TestCase):
q.apply_grad([[1.0], [2.0], [3.0]])
def testAccumulatorDynamicShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=None)
@@ -191,7 +191,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertTrue(is_all_equal)
def testAccumulatorWrongDynamicShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=None)
@@ -209,7 +209,7 @@ class ConditionalAccumulatorTest(test.TestCase):
sess.run(accum_op, feed_dict={x: [[1.0], [2.0], [3.0]]})
def testAccumulatorSizeAfterApplyGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
@@ -220,7 +220,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 2)
def testAccumulatorSizeAfterApplyGradAndTakeGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
accum_op = q.apply_grad((10.0,))
@@ -248,7 +248,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorTakeGradMean(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
@@ -272,7 +272,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(15.0, val)
def testAccumulatorTakeGradSum(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -307,7 +307,7 @@ class ConditionalAccumulatorTest(test.TestCase):
reduction_type="Invalid")
def testAccumulatorInvalidTakeGrad(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0]
@@ -322,7 +322,7 @@ class ConditionalAccumulatorTest(test.TestCase):
takeg_t.eval()
def testAccumulatorRepeatedTakeGradMean(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -349,7 +349,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(elems_ave + 0.0, val)
def testAccumulatorRepeatedTakeGradSum(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -379,7 +379,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(elems_sum, val)
def testAccumulatorIncrementGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -395,7 +395,7 @@ class ConditionalAccumulatorTest(test.TestCase):
inc_global_step.eval()
def testAccumulatorSetGlobalStepPreventsAccumulation(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -416,7 +416,7 @@ class ConditionalAccumulatorTest(test.TestCase):
if x >= ls), val)
def testParallelApplyGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -441,7 +441,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertEqual(val, sum(elems) / len(elems))
def testParallelTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
elems = [e for e in range(10)]
@@ -473,7 +473,7 @@ class ConditionalAccumulatorTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testAccumulatorApplyAndBlockingTake(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
@@ -506,7 +506,7 @@ class ConditionalAccumulatorTest(test.TestCase):
sess.run(takeg_op)
def testAccumulatorCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.ConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
takeg_t = q.take_grad(1)
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py
index 93f5323c41..bc24345261 100644
--- a/tensorflow/python/kernel_tests/confusion_matrix_test.py
+++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py
@@ -37,7 +37,7 @@ class ConfusionMatrixTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testExample(self):
"""This is a test of the example provided in pydoc."""
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([
[0, 0, 0, 0, 0],
[0, 0, 1, 0, 0],
@@ -49,7 +49,7 @@ class ConfusionMatrixTest(test.TestCase):
def _testConfMatrix(self, labels, predictions, truth, weights=None,
num_classes=None):
- with self.test_session():
+ with self.cached_session():
dtype = predictions.dtype
ans = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtype, weights=weights,
@@ -78,7 +78,7 @@ class ConfusionMatrixTest(test.TestCase):
self._testBasic(dtype=np.int64)
def _testConfMatrixOnTensors(self, tf_dtype, np_dtype):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
m_neg = array_ops.placeholder(dtype=dtypes.float32)
m_pos = array_ops.placeholder(dtype=dtypes.float32)
s = array_ops.placeholder(dtype=dtypes.float32)
@@ -229,7 +229,7 @@ class ConfusionMatrixTest(test.TestCase):
def testOutputIsInt32(self):
labels = np.arange(2)
predictions = np.arange(2)
- with self.test_session():
+ with self.cached_session():
cm = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtypes.int32)
tf_cm = cm.eval()
@@ -238,7 +238,7 @@ class ConfusionMatrixTest(test.TestCase):
def testOutputIsInt64(self):
labels = np.arange(2)
predictions = np.arange(2)
- with self.test_session():
+ with self.cached_session():
cm = confusion_matrix.confusion_matrix(
labels, predictions, dtype=dtypes.int64)
tf_cm = cm.eval()
@@ -260,7 +260,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -285,7 +285,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -310,7 +310,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder, expected_rank_diff=0))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -336,7 +336,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -362,7 +362,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder, expected_rank_diff=1))
expected_label_values = np.reshape(label_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_label_values, static_labels.eval())
self.assertAllEqual(prediction_values, static_predictions.eval())
feed_dict = {
@@ -388,7 +388,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
@@ -415,7 +415,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
labels_placeholder, predictions_placeholder, expected_rank_diff=-1))
expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(label_values, static_labels.eval())
self.assertAllEqual(expected_prediction_values, static_predictions.eval())
feed_dict = {
@@ -441,7 +441,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
@@ -466,7 +466,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase):
confusion_matrix.remove_squeezable_dimensions(
labels_placeholder, predictions_placeholder))
- with self.test_session():
+ with self.cached_session():
feed_dict = {
labels_placeholder: label_values,
predictions_placeholder: prediction_values
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 107ee37fab..d1e4e5477f 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -162,18 +162,18 @@ class ConstantTest(test.TestCase):
logging_const_op.run()
def testStringWithNulls(self):
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(b"\0\0\0\0").eval()
self.assertEqual(len(val), 4)
self.assertEqual(val, b"\0\0\0\0")
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(b"xx\0xx").eval()
self.assertEqual(len(val), 5)
self.assertAllEqual(val, b"xx\0xx")
nested = [[b"\0\0\0\0", b"xx\0xx"], [b"\0_\0_\0_\0", b"\0"]]
- with self.test_session():
+ with self.cached_session():
val = ops.convert_to_tensor(nested).eval()
# NOTE(mrry): Do not use assertAllEqual, because it converts nested to a
# numpy array, which loses the null terminators.
@@ -279,7 +279,7 @@ class AsTensorTest(test.TestCase):
self.assertTrue(isinstance(x, ops.Tensor))
def testAsTensorForShapeInput(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(tensor_shape.TensorShape([]))
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual([], x.eval())
@@ -331,7 +331,7 @@ class AsTensorTest(test.TestCase):
tensor_shape.TensorShape([1, 2, 3]), dtype=dtypes_lib.float32)
def testAsTensorForDimensionInput(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(tensor_shape.TensorShape([1, 2, 3])[1])
self.assertEqual(dtypes_lib.int32, x.dtype)
self.assertAllEqual(2, x.eval())
@@ -367,7 +367,7 @@ class IdentityOpTest(test.TestCase):
class ZerosTest(test.TestCase):
def _Zeros(self, shape):
- with self.test_session():
+ with self.cached_session():
ret = array_ops.zeros(shape)
self.assertEqual(shape, ret.get_shape())
return ret.eval()
@@ -379,13 +379,13 @@ class ZerosTest(test.TestCase):
def testScalar(self):
self.assertEqual(0, self._Zeros([]))
self.assertEqual(0, self._Zeros(()))
- with self.test_session():
+ with self.cached_session():
scalar = array_ops.zeros(constant_op.constant([], dtype=dtypes_lib.int32))
self.assertEqual(0, scalar.eval())
def testDynamicSizes(self):
np_ans = np.array([[0] * 3] * 2)
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of 2 x 3.
d = array_ops.fill([2, 3], 12., name="fill")
# Constructs a tensor of zeros of the same dimensions as "d".
@@ -396,7 +396,7 @@ class ZerosTest(test.TestCase):
self.assertShapeEqual(np_ans, z)
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
d = array_ops.fill([2, 3], 12., name="fill")
self.assertEqual(d.get_shape(), [2, 3])
# Test default type for both constant size and dynamic size
@@ -489,7 +489,7 @@ class ZerosLikeTest(test.TestCase):
def testZerosLikeDtype(self):
# Make sure zeros_like works even for dtypes that cannot be cast between
- with self.test_session():
+ with self.cached_session():
shape = (3, 5)
dtypes = np.float32, np.complex64
for in_type in dtypes:
@@ -533,7 +533,7 @@ class ZerosLikeTest(test.TestCase):
class OnesTest(test.TestCase):
def _Ones(self, shape):
- with self.test_session():
+ with self.cached_session():
ret = array_ops.ones(shape)
self.assertEqual(shape, ret.get_shape())
return ret.eval()
@@ -544,13 +544,13 @@ class OnesTest(test.TestCase):
def testScalar(self):
self.assertEqual(1, self._Ones([]))
self.assertEqual(1, self._Ones(()))
- with self.test_session():
+ with self.cached_session():
scalar = array_ops.ones(constant_op.constant([], dtype=dtypes_lib.int32))
self.assertEqual(1, scalar.eval())
def testDynamicSizes(self):
np_ans = np.array([[1] * 3] * 2)
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of 2 x 3.
d = array_ops.fill([2, 3], 12., name="fill")
# Constructs a tensor of ones of the same dimensions as "d".
@@ -561,7 +561,7 @@ class OnesTest(test.TestCase):
self.assertShapeEqual(np_ans, z)
def testAutoPack(self):
- with self.test_session():
+ with self.cached_session():
h = array_ops.placeholder(dtypes_lib.int32, shape=[])
w = array_ops.placeholder(dtypes_lib.int32, shape=[])
z = array_ops.ones([h, w])
@@ -569,7 +569,7 @@ class OnesTest(test.TestCase):
self.assertAllEqual(out, np.array([[1] * 16] * 4))
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
d = array_ops.fill([2, 3], 12., name="fill")
self.assertEqual(d.get_shape(), [2, 3])
# Test default type for both constant size and dynamic size
@@ -606,7 +606,7 @@ class OnesLikeTest(test.TestCase):
dtypes_lib.complex128
]:
numpy_dtype = dtype.as_numpy_dtype
- with self.test_session():
+ with self.cached_session():
# Creates a tensor of non-zero values with shape 2 x 3.
d = constant_op.constant(
np.ones(
@@ -672,7 +672,7 @@ class FillTest(test.TestCase):
self.assertAllEqual(np_ans, tf_ans)
def testFillNegative(self):
- with self.test_session():
+ with self.cached_session():
for shape in (-1,), (2, -1), (-1, 2), (-2), (-3):
with self.assertRaises(ValueError):
array_ops.fill(shape, 7)
@@ -703,7 +703,7 @@ class FillTest(test.TestCase):
self.assertEqual([None, 17], f.get_shape().as_list())
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
in_v = constant_op.constant(5.0)
out_shape = [3, 2]
out_filled = array_ops.fill(out_shape, in_v)
@@ -715,7 +715,7 @@ class FillTest(test.TestCase):
class PlaceholderTest(test.TestCase):
def testDtype(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 10)
@@ -727,7 +727,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval()
def testShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=(10, 10), name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 10)
@@ -744,7 +744,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array[:5, :5]})
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=None, name="p")
p_identity = array_ops.identity(p)
# can feed anything
@@ -756,13 +756,13 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array}), feed_array)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[], name="p")
p_identity = array_ops.identity(p)
self.assertAllClose(p_identity.eval(feed_dict={p: 5}), 5)
def testPartialShape(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
p_identity = array_ops.identity(p)
feed_array = np.random.rand(10, 3)
@@ -774,7 +774,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval(feed_dict={p: feed_array[:5, :2]})
def testPartialShapeWhenNotFed(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.float32, shape=[None, 3], name="p")
p_identity = array_ops.identity(p)
@@ -784,7 +784,7 @@ class PlaceholderTest(test.TestCase):
p_identity.eval()
def testControlDependency(self):
- with self.test_session():
+ with self.cached_session():
p = array_ops.placeholder(dtypes_lib.int32, shape=[], name="p")
with ops.control_dependencies([p]):
c = constant_op.constant(5, dtypes_lib.int32)
@@ -872,7 +872,7 @@ versions {
"""
gdef = graph_pb2.GraphDef()
text_format.Merge(graph, gdef)
- with self.test_session():
+ with self.cached_session():
p, ret = importer.import_graph_def(
gdef, return_elements=["Placeholder:0", "add:0"])
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index eac97af4ed..fc4d2a3809 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -129,7 +129,7 @@ def isum(s, maximum_iterations=None):
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
v = control_flow_ops._Identity(v)
@@ -141,7 +141,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v2.eval())
def testRefEnter(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
@@ -154,7 +154,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v3.eval())
def testRefSwitch(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(7)
p = constant_op.constant(True)
@@ -164,7 +164,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(9, v2.eval())
def testEnterMulExit(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
five = constant_op.constant(5)
@@ -176,7 +176,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
def testEnterShapePropagation(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
# If is_constant=True, the shape information should be propagated.
@@ -190,7 +190,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(enter_v_non_constant.shape, None)
def testSwitchMergeIndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([1, 2, 3, 4, 5, 6])
indices = constant_op.constant([0, 2, 4, 6, 8, 10])
data = ops.IndexedSlices(values, indices)
@@ -204,7 +204,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.arange(0, 12, 2), ind)
def testSwitchDeadBranch(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(True, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -216,7 +216,7 @@ class ControlFlowTest(test.TestCase):
dead_branch.eval()
def testSwitchMergeLess(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
zero = ops.convert_to_tensor(0)
one = ops.convert_to_tensor(1)
@@ -228,7 +228,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.arange(1, 7), result)
def testSwitchMergeAddIdentity(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(False, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -241,7 +241,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
def testSwitchMergeAddMul(self):
- with self.test_session():
+ with self.cached_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
ports = ops.convert_to_tensor(True, name="ports")
switch_op = control_flow_ops.switch(data, ports)
@@ -255,7 +255,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
def testLoop_false(self):
- with self.test_session():
+ with self.cached_session():
false = ops.convert_to_tensor(False)
n = constant_op.constant(10)
@@ -272,7 +272,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testLoop_1(self):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0)
one = constant_op.constant(1)
n = constant_op.constant(10)
@@ -298,7 +298,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testLoop_2(self):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0)
one = constant_op.constant(1)
n = constant_op.constant(10)
@@ -324,7 +324,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, result)
def testDifferentFrame(self):
- with self.test_session():
+ with self.cached_session():
data = array_ops.placeholder(dtypes.float32, shape=[])
enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
@@ -333,7 +333,7 @@ class ControlFlowTest(test.TestCase):
res.eval(feed_dict={data: 1.0})
def testCondBool(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296297")
values = constant_op.constant(10)
@@ -352,7 +352,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([None], grad)
def testFetchable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
control_flow_ops.cond(
constant_op.constant(True), lambda: x + 2, lambda: x + 0)
@@ -367,7 +367,7 @@ class ControlFlowTest(test.TestCase):
sess.run(t, feed_dict={x: 3})
def testFeedable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(2)
i0 = constant_op.constant(0)
r = control_flow_ops.while_loop(lambda i: i < 1000,
@@ -384,10 +384,10 @@ class ControlFlowTest(test.TestCase):
sess.run(r, feed_dict={t: 3})
def testCondIndexedSlices(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296180")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
x = ops.IndexedSlices(values, indices)
@@ -402,10 +402,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(0, ind)
def testCondSparseTensor(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113296161 (SparseTensors)")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -422,10 +422,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r.values.get_shape(), (2,))
def testCondResource(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
rv = resource_variable_ops.ResourceVariable(True)
variables.global_variables_initializer().run()
t = ops.convert_to_tensor(1.0)
@@ -438,10 +436,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113293074")
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
@@ -484,17 +482,14 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, result)
def testCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
- self._testCond_1(use_gpu=True)
+ # TODO(b/116526896): Enable GPU tests.
+ # self._testCond_1(use_gpu=True)
def testCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(10)
r = control_flow_ops.cond(
math_ops.less(1, 0), lambda: math_ops.add(x, 1),
@@ -503,10 +498,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(9, result)
def testCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(10)
pred = math_ops.less(1, 2)
fn1 = lambda: math_ops.add(x, 1)
@@ -518,10 +511,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, result)
def testCond_4(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
v3 = variables.Variable(7)
@@ -542,7 +535,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(7, v3.eval())
def testCond_5(self):
- with self.test_session():
+ with self.cached_session():
alive = constant_op.constant(True, name="alive")
count = constant_op.constant(0, name="count")
@@ -556,10 +549,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(4, count.eval())
def testCond_6(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([7])
age = constant_op.constant(3)
@@ -573,7 +564,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.array([7]), result)
def testCond_7(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(10)
y = constant_op.constant(200)
pred = math_ops.less(1, 2)
@@ -583,10 +574,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([11, 12], sess.run(r))
def testCondRef(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@@ -599,10 +588,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([2.0], r.eval())
def testCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/79881896")
- with self.test_session() as sess:
+ with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -617,7 +606,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5, r.eval())
def testUninitializedRefIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = gen_state_ops.variable(
shape=[1],
dtype=dtypes.float32,
@@ -641,7 +630,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual([1.0], sess.run(merged_op.output))
def testCondSwitchIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the recv identity is not removed by optimization.
@@ -658,7 +647,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondRecvIdentity(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the switch identity is not removed by optimization.
@@ -677,7 +666,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
graph = ops.Graph()
@@ -689,11 +678,11 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
x = constant_op.constant(10.0)
pred = math_ops.less(c, 2)
@@ -706,10 +695,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
def testCondGrad_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
pred = math_ops.less(c, 2)
@@ -726,7 +715,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
def testNestedCond_Simple(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(0., name="X")
y = control_flow_ops.cond(
constant_op.constant(True), lambda: x,
@@ -741,10 +730,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, result.eval())
def testCondGrad_Gather(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113327884")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
pred = math_ops.less(c, 2)
@@ -768,7 +757,7 @@ class ControlFlowTest(test.TestCase):
# Microbenchmark: 256,000 iterations/s.
def testWhile_1(self):
- with self.test_session():
+ with self.cached_session():
n = constant_op.constant(0)
c = lambda x: math_ops.less(x, 10000)
b = lambda x: math_ops.add(x, 1)
@@ -776,7 +765,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10000, r.eval())
def testWhileExternalControlDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
v.initializer.run()
increment = v.assign_add(1.0)
@@ -791,7 +780,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(v.eval(), 1.0)
def testWhileExternalControlDependenciesNoInput(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
v.initializer.run()
increment = v.assign_add(1.0)
@@ -806,7 +795,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(v.eval(), 1.0)
def testWhileWithRefs_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(0)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 100)
@@ -830,19 +819,19 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
def testWhile_2(self):
- with self.test_session():
+ with self.cached_session():
s = constant_op.constant(0)
r = isum(s)
self.assertAllEqual(45, r.eval())
def testWhileWithMaximumIterations(self):
- with self.test_session():
+ with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
def testWhileWithMaximumIterationsAndSingleArgument(self):
- with self.test_session():
+ with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
@@ -916,7 +905,7 @@ class ControlFlowTest(test.TestCase):
_ = gradients_impl.gradients(loop_with_maxiter, v)
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
v = constant_op.constant(1.0)
@@ -1019,7 +1008,7 @@ class ControlFlowTest(test.TestCase):
# Have more than 10 parallel iterations and hence exercise k-bound
# most of the time.
def testWhile_3(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, m, c, o):
m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
@@ -1039,7 +1028,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10100, result)
def testWhile_4(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, m, c, o):
m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
@@ -1060,7 +1049,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42, result)
def testWhile_5(self):
- with self.test_session():
+ with self.cached_session():
def compute(i, c, o):
c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
@@ -1088,7 +1077,7 @@ class ControlFlowTest(test.TestCase):
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("/cpu:0"):
c = constant_op.constant(2)
i0 = constant_op.constant(0)
@@ -1134,7 +1123,7 @@ class ControlFlowTest(test.TestCase):
self._testWhile_Gpu_1(use_gpu=True)
def testWhileShape(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0)
m = array_ops.ones([2, 2])
c = lambda i, j: math_ops.less(i, 2)
@@ -1151,7 +1140,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(np.ones((8, 8)), r.eval())
def testWhileWithNonTensorInput_Scalar(self):
- with self.test_session():
+ with self.cached_session():
n = 0
c = lambda x: x < 10000
b = lambda x: x + 1
@@ -1159,7 +1148,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10000, r.eval())
def testWhileWithNonTensorInput_Vector(self):
- with self.test_session():
+ with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
c = lambda x: x[0] < 10000
b = lambda x: array_ops.stack([x[0] + 1])
@@ -1167,7 +1156,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([10000], r.eval())
def testWhileShapeInference(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0)
m = array_ops.ones([2, 2])
c = lambda i, j: math_ops.less(i, 2)
@@ -1192,7 +1181,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [i, m])
def testWhileShapeInferenceSparseTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -1223,7 +1212,7 @@ class ControlFlowTest(test.TestCase):
[i.get_shape(), tensor_shape.TensorShape([5])])
def testWhileShapeInferenceIndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
indices = constant_op.constant([0, 3], name="indices")
shape = constant_op.constant([10, 2], name="dense_shape")
@@ -1313,7 +1302,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhile_2(use_gpu=True)
def testWhileWithControl_1(self):
- with self.test_session():
+ with self.cached_session():
n = constant_op.constant(0)
r = constant_op.constant(0)
condition = lambda n_, r_: math_ops.less(n_, 10)
@@ -1329,7 +1318,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, res[1].eval())
def testWhileWithControl_2(self):
- with self.test_session():
+ with self.cached_session():
r = constant_op.constant(0)
condition = lambda r_: math_ops.less(r_, 10)
@@ -1343,7 +1332,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(12, res.eval())
def testWhileWithControl_3(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1352,7 +1341,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileWithControl_4(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1362,7 +1351,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileWithControl_5(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
c = constant_op.constant(1)
x0 = constant_op.constant(0)
@@ -1375,12 +1364,12 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, sess.run(r, {b: True}))
def testWhileCondWithControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const_true = lambda: constant_op.constant(True)
const_false = lambda: constant_op.constant(False)
cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
@@ -1392,10 +1381,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, sess.run(loop))
def testWhileCondWithControl_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
i0 = constant_op.constant(0)
@@ -1417,10 +1406,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(65536.0, v.eval())
def testWhileCondExitControl(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(1)
def false_branch():
@@ -1443,10 +1432,10 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ if control_flow_ops.ENABLE_COND_V2:
+ return unittest.skip("b/113294340 (enable while_v2)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
b = lambda x: math_ops.add(x, 1)
@@ -1456,10 +1445,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
+ if control_flow_ops.ENABLE_COND_V2:
+ return unittest.skip("b/113294340 (enable while_v2)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
b = lambda x: math_ops.add(x, 1)
@@ -1469,7 +1458,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session(use_gpu=use_gpu) as sess:
@@ -1498,10 +1487,10 @@ class ControlFlowTest(test.TestCase):
self._testCondWhile_3(use_gpu=True)
def testWhileCond_1(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
n = ops.convert_to_tensor(10, name="n")
one = ops.convert_to_tensor(1, name="one")
@@ -1516,10 +1505,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_2(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
c = lambda x: math_ops.less(x, 10)
b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
@@ -1527,10 +1516,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testWhileCond_3(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(0)
c = lambda x: math_ops.less(x, 10)
# pylint: disable=undefined-variable
@@ -1544,7 +1533,7 @@ class ControlFlowTest(test.TestCase):
# NOTE: It is ok to have parallel_iterations > 1
def testWhileUpdateVariable_1(self):
- with self.test_session():
+ with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1566,7 +1555,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
def testWhileUpdateVariable_2(self):
- with self.test_session():
+ with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
select2 = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1592,7 +1581,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
def testWhileUpdateVariable_3(self):
- with self.test_session():
+ with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
n = constant_op.constant(0)
@@ -1614,7 +1603,7 @@ class ControlFlowTest(test.TestCase):
# b/24814703
def testWhileUpdateVariable_4(self):
- with self.test_session():
+ with self.cached_session():
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
variables.global_variables_initializer().run()
@@ -1642,7 +1631,7 @@ class ControlFlowTest(test.TestCase):
# b/24736492
def testWhileUpdateVariable_5(self):
- with self.test_session():
+ with self.cached_session():
# Create some variables.
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
@@ -1672,7 +1661,7 @@ class ControlFlowTest(test.TestCase):
# b/24814668
def testWhileUpdateVariable_6(self):
- with self.test_session():
+ with self.cached_session():
# Create some variables.
var_a = variables.Variable(0, name="a")
var_b = variables.Variable(0, name="b")
@@ -1701,7 +1690,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
def testWhileQueue_1(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
i = constant_op.constant(0)
@@ -1719,7 +1708,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([i], q.dequeue().eval())
def testWhileStack_1(self):
- with self.test_session():
+ with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
i = constant_op.constant(0)
@@ -1753,7 +1742,7 @@ class ControlFlowTest(test.TestCase):
def _testWhileGrad_ColocateGradients(self, colocate):
gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
- ) else "/device:GPU:0"
+ ) else "/device:CPU:0"
graph = ops.Graph()
with graph.as_default():
@@ -1783,7 +1772,7 @@ class ControlFlowTest(test.TestCase):
else:
self.assertFalse(gpu_dev_name in dev)
- with self.test_session(graph=graph) as sess:
+ with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
def testWhileGrad_ColocateGradients(self):
@@ -1791,7 +1780,7 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_ColocateGradients(colocate=True)
def testWhileGrad_Square(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1802,7 +1791,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileGrad_Shape(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[None])
v = constant_op.constant([2.0], name="v")
n = constant_op.constant(0, name="n")
@@ -1819,7 +1808,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
def testWhileGrad_BaseShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, [None])
v0 = constant_op.constant([2.0, 2.0], name="v")
c = lambda v: constant_op.constant(False)
@@ -1831,7 +1820,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
def testWhileGrad_MultipleUses(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1842,7 +1831,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(524288.0, r.eval())
def testWhileGrad_LoopAdd(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -1872,7 +1861,7 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
@@ -1901,7 +1890,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhileCondWhileGrad(use_gpu=True)
def testWhileGrad_Variable(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable(3.0)
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
@@ -1913,10 +1902,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/110550782 (gradient w.r.t external variable)")
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
x = array_ops.placeholder(dtypes.float32, shape=None)
c = lambda n: math_ops.less(n, 10.0)
@@ -1931,7 +1920,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testGradInWhileWrtInitialLoopVal(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
y = x + 1
@@ -1948,7 +1937,7 @@ class ControlFlowTest(test.TestCase):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
def testWhileGradInWhile(self):
- with self.test_session():
+ with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
x = array_ops.placeholder(dtypes.float32, shape=None)
c = lambda n: math_ops.less(n, 10.0)
@@ -1964,7 +1953,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
def testCondGradInNestedWhiles(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
@@ -1978,13 +1967,13 @@ class ControlFlowTest(test.TestCase):
i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i_val, x_val = sess.run([i, x])
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
def testWhile_NestedInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2011,7 +2000,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r_flattened))
def testWhile_NestedBadArityFails(self):
- with self.test_session():
+ with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
loop_vars = [
named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
@@ -2027,7 +2016,7 @@ class ControlFlowTest(test.TestCase):
control_flow_ops.while_loop(c, b, loop_vars)
def testWhileGrad_ys_xs(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2050,7 +2039,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(120.0, r[0].eval())
def testWhileGrad_Dependency(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2069,7 +2058,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_NoGradient(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
b = math_ops.square
@@ -2079,7 +2068,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1.0, r[0].eval())
def testWhileGrad_NoDependency(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
duration = array_ops.zeros([], dtype=dtypes.int32)
@@ -2099,7 +2088,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
def testWhileGrad_Const(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c0 = constant_op.constant(0.0, name="c0")
c1 = constant_op.constant(1.0, name="c1")
duration = constant_op.constant(0, name="t")
@@ -2118,7 +2107,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, sess.run(grad[0]))
def testWhileGrad_SerialTwoLoops(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2136,7 +2125,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r[0].eval())
def testWhileGrad_ParallelTwoLoops(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(2.0, name="x")
@@ -2155,7 +2144,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(64.0, r[0].eval())
def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(0, name="i")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(1.0, name="y")
@@ -2196,7 +2185,7 @@ class ControlFlowTest(test.TestCase):
self._testNestedWhileGrad_Simple(use_gpu=True)
def testNestedWhileGrad_SerialInner(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(1.0)
def inner_loop1(s):
@@ -2219,7 +2208,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(256.0, r.eval())
def testNestedWhileGrad_ParallelInner(self):
- with self.test_session():
+ with self.cached_session():
v = constant_op.constant(1.0)
def inner_loop1(s):
@@ -2244,7 +2233,7 @@ class ControlFlowTest(test.TestCase):
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def inner_loop(t):
fn = lambda n: n + math_ops.square(var)
@@ -2280,14 +2269,14 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r.eval())
def testWhileCondGrad_Simple(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113294377 (unknown shape)")
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
def testWhileCondGrad_UnknownShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
n = ops.convert_to_tensor(100.0, name="n")
one = ops.convert_to_tensor(1.0, name="one")
@@ -2304,7 +2293,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(1024.0, r)
def testWhileGrad_Concat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variable_scope.get_variable("x", initializer=[[1., 2.]])
i0 = constant_op.constant(0)
h0 = array_ops.zeros([0, 2])
@@ -2327,7 +2316,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
def testWhileWithRefsWithGradients_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(0.)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 10)
@@ -2355,7 +2344,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(73, value_x_grad)
def testWhileGrad_IndexedSlices(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant([0, 3], name="indices")
shape = constant_op.constant([10], name="dense_shape")
@@ -2376,7 +2365,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
def testWhileGrad_SparseTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
[[0], [3]], dtype=dtypes.int64, name="indices")
@@ -2398,7 +2387,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
def testCallGradInLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i0 = constant_op.constant(0)
params = constant_op.constant(5.0)
params_1 = math_ops.square(params)
@@ -2417,7 +2406,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(600.0, sess.run(output_grad)[1])
def testWhileAndTensorArray(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
param = constant_op.constant(2.0)
n0 = constant_op.constant(0)
y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
@@ -2436,7 +2425,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(107520.0, sess.run(r))
def testWhileGrad_StopGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2479,7 +2468,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(32.0, r.eval())
def testWhileGrad_StopGradInside(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(3.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2498,7 +2487,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(156.0, r.eval())
def testWhileGrad_StopGradInsideNoShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
y = array_ops.placeholder(dtypes.float32)
@@ -2534,7 +2523,7 @@ class ControlFlowTest(test.TestCase):
gradients_impl.gradients(grad_theta_stopped, theta)
def testStopGradOnWhileGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(2.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -2562,7 +2551,7 @@ class ControlFlowTest(test.TestCase):
_, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
dy_dq, = gradients_impl.gradients(y, q)
self.assertIsNotNone(dy_dq)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
@@ -2579,7 +2568,7 @@ class ControlFlowTest(test.TestCase):
_, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
dy_dq, = gradients_impl.gradients(y, q)
self.assertIsNotNone(dy_dq)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
@@ -2607,7 +2596,7 @@ class ControlFlowTest(test.TestCase):
self.assertIsNotNone(grad)
def testStopGradMultiFlows(self):
- with self.test_session():
+ with self.cached_session():
def body(i, y, r):
x = variable_scope.get_variable(
@@ -2633,10 +2622,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(5.0, result.eval())
def testOneValueCond(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
one = ops.convert_to_tensor(1, name="one")
two = ops.convert_to_tensor(2, name="two")
@@ -2651,10 +2638,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([2], i.eval(feed_dict={c: 0}))
def testExampleCond(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("b/111124878 (don't return tuple)")
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")
d = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2669,10 +2654,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
def testCase(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
z = constant_op.constant(3)
@@ -2724,10 +2709,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
def testCaseSideEffects(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/112477618 (Operation returned from cond)")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
v2 = variables.Variable(-1)
@@ -2762,10 +2747,10 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
def testOneOpCond(self):
- if control_flow_ops._ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_COND_V2:
return unittest.skip("b/113324949 (ref vars)")
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
one = ops.convert_to_tensor(1)
@@ -2793,7 +2778,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(2, v.eval())
def testWithOpsDependencies(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = variables.Variable(0.0)
c = constant_op.constant(10)
@@ -2816,7 +2801,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, real_v_val)
def testWithTensorDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(0.0)
c1 = constant_op.constant(10)
c2 = constant_op.constant(20)
@@ -2842,7 +2827,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(0.0, v.eval())
def testWithIndexedSlicesDependencies(self):
- with self.test_session():
+ with self.cached_session():
v = variables.Variable(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
@@ -2886,7 +2871,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
def testGroup(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v1 = variables.Variable([0.0])
v2 = variables.Variable([1.0])
@@ -2997,7 +2982,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(None, s.get_shape())
def testRunLoopTensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor_list = []
def condition(t):
@@ -3021,7 +3006,7 @@ class ControlFlowTest(test.TestCase):
def func(x):
return np.square(x)
- with self.test_session():
+ with self.cached_session():
r = control_flow_ops.while_loop(
lambda i, v: i < 4,
lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
@@ -3035,7 +3020,7 @@ class ControlFlowTest(test.TestCase):
def func(x):
return math_ops.square(math_ops.square(x))
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(2.0, dtypes.float32)
r = control_flow_ops.while_loop(
lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
@@ -3174,7 +3159,7 @@ class TupleTest(test.TestCase):
def testTensors(self):
for v1_first in [True, False]:
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable([1.0])
add1 = math_ops.add(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
@@ -3204,7 +3189,7 @@ class TupleTest(test.TestCase):
def testIndexedSlices(self):
for v1_first in [True, False]:
- with self.test_session():
+ with self.cached_session():
v1 = variables.Variable(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
@@ -3243,7 +3228,7 @@ class TupleTest(test.TestCase):
v1.eval())
def testAcceptTensorsAsControlInputs(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(0)
assign = state_ops.assign(var, 1)
t, = control_flow_ops.tuple(
@@ -3408,6 +3393,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
+@test_util.with_cond_v2
class EagerTest(test.TestCase):
def testCond(self):
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py
index fcba456004..2d6d8a8051 100644
--- a/tensorflow/python/kernel_tests/conv1d_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_test.py
@@ -53,7 +53,7 @@ class Conv1DTest(test.TestCase):
self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
def testConv1DTranspose(self):
- with self.test_session():
+ with self.cached_session():
stride = 2
# Input, output: [batch, width, depth]
diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
index be299beee4..644a151710 100644
--- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv2DBackpropFilterGradTest(test.TestCase):
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv2d_transpose_test.py b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
index 27804be65c..cbdd2c5991 100644
--- a/tensorflow/python/kernel_tests/conv2d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_transpose_test.py
@@ -37,7 +37,7 @@ from tensorflow.python.platform import test
class Conv2DTransposeTest(test.TestCase):
def testConv2DTransposeSingleStride(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 1, 1, 1]
# Input, output: [batch, height, width, depth]
@@ -75,7 +75,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeSame(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
@@ -108,7 +108,7 @@ class Conv2DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, h, w, k])
def testConv2DTransposeValid(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 1]
# Input, output: [batch, height, width, depth]
@@ -163,7 +163,7 @@ class Conv2DTransposeTest(test.TestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv2d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
index 85264ef876..89b64068ac 100644
--- a/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_backprop_filter_v2_grad_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv3DBackpropFilterV2GradTest(test.TestCase):
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
for padding in ["SAME", "VALID"]:
for stride in [1, 2]:
np.random.seed(1)
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index 289ae29fce..2527b83769 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -32,7 +32,7 @@ from tensorflow.python.platform import test
class Conv3DTransposeTest(test.TestCase):
def testConv3DTransposeSingleStride(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 1, 1, 1, 1]
# Input, output: [batch, depth, height, width, channel]
@@ -82,7 +82,7 @@ class Conv3DTransposeTest(test.TestCase):
self.assertAllClose(target, value[n, d, h, w, k])
def testConv3DTransposeSame(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -134,7 +134,7 @@ class Conv3DTransposeTest(test.TestCase):
def testConv3DTransposeOutputShapeType(self):
# Test case for GitHub issue 18887
for dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session():
+ with self.cached_session():
x_shape = [2, 5, 6, 4, 3]
y_shape = [2, 5, 6, 4, 2]
f_shape = [3, 3, 3, 2, 3]
@@ -149,7 +149,7 @@ class Conv3DTransposeTest(test.TestCase):
output.eval()
def testConv3DTransposeValid(self):
- with self.test_session():
+ with self.cached_session():
strides = [1, 2, 2, 2, 1]
# Input, output: [batch, depth, height, width, depth]
@@ -209,7 +209,7 @@ class Conv3DTransposeTest(test.TestCase):
np.random.seed(1) # Make it reproducible.
x_val = np.random.random_sample(x_shape).astype(np.float64)
f_val = np.random.random_sample(f_shape).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
output = nn_ops.conv3d_transpose(
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index 0b531125f3..6794464e3a 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -108,7 +108,7 @@ class Conv3DTest(test.TestCase):
use_gpu=use_gpu)
results.append(result)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = sess.run(results)
for value in values:
print("expected = ", expected)
@@ -183,7 +183,7 @@ class Conv3DTest(test.TestCase):
expected_results.append(expected)
computed_results.append(computed)
tolerance = 1e-2 if use_gpu else 1e-5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_values = sess.run(expected_results)
computed_values = sess.run(computed_results)
for e_value, c_value in zip(expected_values, computed_values):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 00de94f004..ea611497d9 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1474,7 +1474,7 @@ class Conv2DTest(test.TestCase):
padding="SAME")
def testOpEdgeCases(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Illegal strides.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"strides in the batch and depth"):
@@ -1539,7 +1539,7 @@ class DepthwiseConv2DTest(test.TestCase):
# numbers from 1.
x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
t1 = constant_op.constant(x1, shape=tensor_in_sizes)
t1.set_shape(tensor_in_sizes)
t2 = constant_op.constant(x2, shape=filter_in_sizes)
diff --git a/tensorflow/python/kernel_tests/cross_grad_test.py b/tensorflow/python/kernel_tests/cross_grad_test.py
index f040ac6055..0bd4006d6a 100644
--- a/tensorflow/python/kernel_tests/cross_grad_test.py
+++ b/tensorflow/python/kernel_tests/cross_grad_test.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import test
class CrossOpTest(test.TestCase):
def testGradientRandomValues(self):
- with self.test_session():
+ with self.cached_session():
us = [2, 3]
u = array_ops.reshape(
[0.854, -0.616, 0.767, 0.725, -0.927, 0.159], shape=us)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
new file mode 100644
index 0000000000..8028f93a8c
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -0,0 +1,878 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for binary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_ADD = lambda x, y: x + y
+_SUB = lambda x, y: x - y
+_MUL = lambda x, y: x * y
+_POW = lambda x, y: x**y
+_TRUEDIV = lambda x, y: x / y
+_FLOORDIV = lambda x, y: x // y
+_MOD = lambda x, y: x % y
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+ x[x < thresh] = 0
+
+ non_zero = np.where(x)
+ x_indices = np.vstack(non_zero).astype(index_dtype).T
+ x_values = x[non_zero]
+ x_shape = x.shape
+
+ return sparse_tensor.SparseTensor(
+ indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+
+
+def _default_tolerance(dtype):
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
+ if dtype == np.float16:
+ return 5e-3
+ elif dtype in (np.float32, np.complex64):
+ return 1e-3
+ elif dtype in (np.float64, np.complex128):
+ return 1e-5
+ else:
+ return None # Fail fast for unexpected types
+
+
+class BinaryOpTest(test.TestCase):
+
+ def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
+ np_ans = np_func(x, y)
+ with self.test_session(use_gpu=False):
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_cpu = out.eval()
+ # Test that the op takes precedence over numpy operators.
+ np_left = tf_func(x, iny).eval()
+ np_right = tf_func(inx, y).eval()
+
+ if also_compare_variables:
+ var_x = variables.Variable(x)
+ var_y = variables.Variable(y)
+ variables.global_variables_initializer().run()
+ print(type(x), type(y), type(var_x), type(var_y))
+ print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
+ np_var_left = tf_func(x, var_y).eval()
+ np_var_right = tf_func(var_x, y).eval()
+
+ if np_ans.dtype != np.object:
+ self.assertAllClose(np_ans, tf_cpu)
+ self.assertAllClose(np_ans, np_left)
+ self.assertAllClose(np_ans, np_right)
+ if also_compare_variables:
+ self.assertAllClose(np_ans, np_var_left)
+ self.assertAllClose(np_ans, np_var_right)
+ self.assertShapeEqual(np_ans, out)
+
+ _GRAD_TOL = {
+ dtypes_lib.float16: 1e-3,
+ dtypes_lib.float32: 1e-3,
+ dtypes_lib.complex64: 1e-2,
+ dtypes_lib.float64: 1e-5,
+ dtypes_lib.complex128: 1e-4
+ }
+
+ def _compareGradientX(self,
+ x,
+ y,
+ np_func,
+ tf_func,
+ numeric_gradient_type=None):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.cached_session():
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ if x.dtype in (np.float32, np.float64):
+ out = 1.1 * tf_func(inx, iny)
+ else:
+ out = tf_func(inx, iny)
+ xs = list(x.shape)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, out, zs, x_init_value=x)
+ if numeric_gradient_type is not None:
+ xf = x.astype(numeric_gradient_type)
+ yf = y.astype(numeric_gradient_type)
+ inxf = ops.convert_to_tensor(xf)
+ inyf = ops.convert_to_tensor(yf)
+ outf = tf_func(inxf, inyf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
+ jacob_n = jacob_n.astype(x.dtype)
+ tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+ self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+ def _compareGradientY(self,
+ x,
+ y,
+ np_func,
+ tf_func,
+ numeric_gradient_type=None):
+ z = np_func(x, y)
+ zs = list(z.shape)
+ with self.cached_session():
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ if x.dtype in (np.float32, np.float64):
+ out = 1.1 * tf_func(inx, iny)
+ else:
+ out = tf_func(inx, iny)
+ ys = list(np.shape(y))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ iny, ys, out, zs, x_init_value=y)
+ if numeric_gradient_type is not None:
+ xf = x.astype(numeric_gradient_type)
+ yf = y.astype(numeric_gradient_type)
+ inxf = ops.convert_to_tensor(xf)
+ inyf = ops.convert_to_tensor(yf)
+ outf = tf_func(inxf, inyf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inyf, ys, outf, zs, x_init_value=yf)
+ jacob_n = jacob_n.astype(x.dtype)
+ tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
+ self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
+
+ def _compareGpu(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ out = tf_func(inx, iny)
+ tf_gpu = out.eval()
+ self.assertAllClose(np_ans, tf_gpu)
+ self.assertShapeEqual(np_ans, out)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
+ self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
+ if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
+ np.complex128):
+ if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
+ math_ops.polygamma):
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ if tf_func in (math_ops.zeta, math_ops.polygamma):
+ # These methods only support gradients in the second parameter
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
+
+ def testFloatBasic(self):
+ x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
+ y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+ x1 = np.random.randn(5, 6).astype(np.float32)
+ x2 = np.random.randn(5, 6).astype(np.float32)
+ # Remove tiny values--atan2 gradients are flaky near the origin.
+ x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
+ x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
+ self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+ x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+ math_ops.igamma)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+ math_ops.igammac)
+ # Need x > 1
+ self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
+ math_ops.zeta)
+ n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(n_small, x_pos_small, special.polygamma,
+ math_ops.polygamma)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ def testFloatDifferentShapes(self):
+ x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
+ y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
+ with self.cached_session() as sess:
+ inx = ops.convert_to_tensor(x)
+ iny = ops.convert_to_tensor(y)
+ s = math_ops.reduce_sum(inx * iny)
+ gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
+ # gx is simply the broadcasted y
+ self.assertAllEqual(gx,
+ np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
+ # gy is x's column summed up
+ self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
+
+ def testFloatVariableOverload(self):
+ x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
+ y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
+ var_x = variables.Variable(x)
+ var_y = variables.Variable(y)
+ with self.cached_session() as sess:
+ sess.run([var_x.initializer, var_y.initializer])
+ left_result = (var_x * y).eval()
+ right_result = (x * var_y).eval()
+ np_result = x * y
+ self.assertAllEqual(np_result, left_result)
+ self.assertAllEqual(np_result, right_result)
+
+ def testDoubleBasic(self):
+ x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
+ y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.arctan2, math_ops.atan2)
+ x1 = np.random.randn(7, 4).astype(np.float64)
+ x2 = np.random.randn(7, 4).astype(np.float64)
+ # Remove tiny values--atan2 gradients are flaky near the origin.
+ x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
+ x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
+ self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
+ x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
+ math_ops.igamma)
+ self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
+ math_ops.igammac)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ def testUint8Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
+ self._compareBoth(x, y, np.add, math_ops.add)
+
+ def testInt8Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testInt16Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+
+ def testUint16Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+
+ def testInt32Basic(self):
+ x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.mod, math_ops.mod)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+ # _compareBoth tests on GPU only for floating point types, so test
+ # _MOD for int32 on GPU by calling _compareGpu
+ self._compareGpu(x, y, np.mod, _MOD)
+
+ def testInt64Basic(self):
+ x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
+ self._compareBoth(x, y, np.mod, math_ops.mod)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y, np.true_divide, _TRUEDIV)
+ self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
+ self._compareBoth(x, y, np.mod, _MOD)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex64)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
+ np.complex128)
+ self._compareBoth(x, y, np.add, math_ops.add)
+ self._compareBoth(x, y, np.subtract, math_ops.subtract)
+ self._compareBoth(x, y, np.multiply, math_ops.multiply)
+ self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
+ self._compareBoth(x, y, np.add, _ADD)
+ self._compareBoth(x, y, np.subtract, _SUB)
+ self._compareBoth(x, y, np.multiply, _MUL)
+ self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
+
+ def testStringComparison(self):
+ x = np.array([["abc", "bh"], ["c", ""]])
+ y = np.array([["abc", "bh"], ["def", "hi"]])
+ with self.test_session(use_gpu=False) as sess:
+ cmp_eq = math_ops.equal(x, y)
+ cmp_not_eq = math_ops.not_equal(x, y)
+ values = sess.run([cmp_eq, cmp_not_eq])
+ self.assertAllEqual([[True, True], [False, False]], values[0])
+ self.assertAllEqual([[False, False], [True, True]], values[1])
+
+ def testString(self):
+ x = np.array([["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
+ ["x_2_0", "x_2_1", "x_2_2"]],
+ dtype=np.object)
+ y = np.array([["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
+ ["y_2_0", "y_2_1", "y_2_2"]],
+ dtype=np.object)
+ z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
+ w = np.array("w", dtype=np.object)
+ self._compareCpu(x, y, _ADD, _ADD)
+ self._compareCpu(x, z, _ADD, _ADD)
+ self._compareCpu(x, w, _ADD, _ADD)
+ self._compareCpu(z, w, _ADD, _ADD)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ if dtype in (np.complex64, np.complex128):
+ x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
+ y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
+ else:
+ x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
+ y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
+ self._compareCpu(x, y, np_func, tf_func)
+ if x.dtype in (np.float16, np.float32, np.float64):
+ # TODO(aselle): Make the test work for dtypes:
+ # (np.complex64, np.complex128).
+ if tf_func not in (_FLOORDIV, math_ops.floordiv):
+ if x.dtype == np.float16:
+ # Compare fp16 theoretical gradients to fp32 numerical gradients,
+ # since fp16 numerical gradients are too imprecise unless great
+ # care is taken with choosing the inputs and the delta. This is
+ # a weaker check (in particular, it does not test the op itself,
+ # only its gradient), but it's much better than nothing.
+ self._compareGradientX(x, y, np_func, tf_func, np.float)
+ self._compareGradientY(x, y, np_func, tf_func, np.float)
+ else:
+ self._compareGradientX(x, y, np_func, tf_func)
+ self._compareGradientY(x, y, np_func, tf_func)
+ self._compareGpu(x, y, np_func, tf_func)
+
+ # TODO(josh11b,vrv): Refactor this to use parameterized tests.
+ def _testBCastByFunc(self, funcs, xs, ys):
+ dtypes = [
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ np.complex64,
+ np.complex128,
+ ]
+ for dtype in dtypes:
+ for (np_func, tf_func) in funcs:
+ if (dtype in (np.complex64, np.complex128) and
+ tf_func in (_FLOORDIV, math_ops.floordiv)):
+ continue # floordiv makes no sense for complex numbers
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+ self._compareBCast(ys, xs, dtype, np_func, tf_func)
+
+ def _testBCastA(self, xs, ys):
+ funcs = [
+ (np.add, math_ops.add),
+ (np.add, _ADD),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastB(self, xs, ys):
+ funcs = [
+ (np.subtract, math_ops.subtract),
+ (np.subtract, _SUB),
+ (np.power, math_ops.pow),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastC(self, xs, ys):
+ funcs = [
+ (np.multiply, math_ops.multiply),
+ (np.multiply, _MUL),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def _testBCastD(self, xs, ys):
+ funcs = [
+ (np.true_divide, math_ops.truediv),
+ (np.floor_divide, math_ops.floordiv),
+ (np.true_divide, _TRUEDIV),
+ (np.floor_divide, _FLOORDIV),
+ ]
+ self._testBCastByFunc(funcs, xs, ys)
+
+ def testBCast_0A(self):
+ self._testBCastA([1, 3, 2], [1])
+
+ def testBCast_0B(self):
+ self._testBCastB([1, 3, 2], [1])
+
+ def testBCast_0C(self):
+ self._testBCastC([1, 3, 2], [1])
+
+ def testBCast_0D(self):
+ self._testBCastD([1, 3, 2], [1])
+
+ def testBCast_1A(self):
+ self._testBCastA([1, 3, 2], [2])
+
+ def testBCast_1B(self):
+ self._testBCastB([1, 3, 2], [2])
+
+ def testBCast_1C(self):
+ self._testBCastC([1, 3, 2], [2])
+
+ def testBCast_1D(self):
+ self._testBCastD([1, 3, 2], [2])
+
+ def testBCast_2A(self):
+ self._testBCastA([1, 3, 2], [3, 2])
+
+ def testBCast_2B(self):
+ self._testBCastB([1, 3, 2], [3, 2])
+
+ def testBCast_2C(self):
+ self._testBCastC([1, 3, 2], [3, 2])
+
+ def testBCast_2D(self):
+ self._testBCastD([1, 3, 2], [3, 2])
+
+ def testBCast_3A(self):
+ self._testBCastA([1, 3, 2], [3, 1])
+
+ def testBCast_3B(self):
+ self._testBCastB([1, 3, 2], [3, 1])
+
+ def testBCast_3C(self):
+ self._testBCastC([1, 3, 2], [3, 1])
+
+ def testBCast_3D(self):
+ self._testBCastD([1, 3, 2], [3, 1])
+
+ def testBCast_4A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_4D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_5A(self):
+ self._testBCastA([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5B(self):
+ self._testBCastB([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5C(self):
+ self._testBCastC([1, 3, 2], [2, 3, 1])
+
+ def testBCast_5D(self):
+ self._testBCastD([1, 3, 2], [2, 3, 1])
+
+ def testBCast_6A(self):
+ self._testBCastA([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6B(self):
+ self._testBCastB([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6C(self):
+ self._testBCastC([1, 3, 2], [2, 1, 1])
+
+ def testBCast_6D(self):
+ self._testBCastD([1, 3, 2], [2, 1, 1])
+
+ def testBCast_7A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 1])
+
+ def testBCast_7D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 1])
+
+ def testBCast_8A(self):
+ self._testBCastA([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8B(self):
+ self._testBCastB([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8C(self):
+ self._testBCastC([2, 1, 5], [2, 3, 1])
+
+ def testBCast_8D(self):
+ self._testBCastD([2, 1, 5], [2, 3, 1])
+
+ def testBCast_9A(self):
+ self._testBCastA([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9B(self):
+ self._testBCastB([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9C(self):
+ self._testBCastC([2, 0, 5], [2, 0, 1])
+
+ def testBCast_9D(self):
+ self._testBCastD([2, 0, 5], [2, 0, 1])
+
+ def testBCast_10A(self):
+ self._testBCastA([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10B(self):
+ self._testBCastB([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10C(self):
+ self._testBCastC([2, 3, 0], [2, 3, 1])
+
+ def testBCast_10D(self):
+ self._testBCastD([2, 3, 0], [2, 3, 1])
+
+ def testBCast_11A(self):
+ self._testBCastA([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11B(self):
+ self._testBCastB([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11C(self):
+ self._testBCastC([1, 3, 2], [1, 3, 2])
+
+ def testBCast_11D(self):
+ self._testBCastD([1, 3, 2], [1, 3, 2])
+
+ def testBCast_12A(self):
+ self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12B(self):
+ self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12C(self):
+ self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_12D(self):
+ self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
+
+ def testBCast_13A(self):
+ self._testBCastA([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13B(self):
+ self._testBCastB([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13C(self):
+ self._testBCastC([1, 3, 2, 1, 1], [1])
+
+ def testBCast_13D(self):
+ self._testBCastD([1, 3, 2, 1, 1], [1])
+
+ def testBCast_14A(self):
+ self._testBCastA([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14B(self):
+ self._testBCastB([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14C(self):
+ self._testBCastC([2, 3, 1, 1, 5], [1])
+
+ def testBCast_14D(self):
+ self._testBCastD([2, 3, 1, 1, 5], [1])
+
+ def testBCast_15A(self):
+ self._testBCastA([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15B(self):
+ self._testBCastB([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15C(self):
+ self._testBCastC([10, 3, 1, 2], [3, 1, 2])
+
+ def testBCast_15D(self):
+ self._testBCastD([10, 3, 1, 2], [3, 1, 2])
+
+ def testMismatchedDimensions(self):
+ for func in [
+ math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
+ _SUB, _MUL, _TRUEDIV, _FLOORDIV
+ ]:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Dimensions must" in str(e)):
+ func(
+ ops.convert_to_tensor([10.0, 20.0, 30.0]),
+ ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
+
+ def testZeroPowGrad(self):
+ with self.cached_session():
+ for dtype in (np.float16, np.float32, np.float64, np.complex64,
+ np.complex128):
+ x = constant_op.constant(0.0, dtype=dtype)
+ y = constant_op.constant(2.0, dtype=dtype)
+ z = math_ops.pow(x, y)
+ error = gradient_checker.compute_gradient_error(y, [], z, [])
+ self.assertEqual(error, 0)
+
+ def testComplexPowGrad(self):
+ with self.cached_session():
+ for dtype in np.complex64, np.complex128:
+ for base in 2.0, -2.0:
+ x = constant_op.constant(base, dtype=dtype)
+ y = constant_op.constant(2.0, dtype=dtype)
+ z = math_ops.pow(x, y)
+ error = gradient_checker.compute_gradient_error(y, [], z, [])
+ self.assertLess(error, 2e-4)
+
+ def testAtan2SpecialValues(self):
+ x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
+ (1.2345, float("inf")), (1.2345, -float("inf")),
+ (-4.321, float("inf")), (-4.125, -float("inf")),
+ (float("inf"), float("inf")), (float("inf"), -float("inf")),
+ (-float("inf"), float("inf")),
+ (-float("inf"), -float("inf")))
+ for dtype in np.float32, np.float64:
+ x1 = np.array(x1l).astype(dtype)
+ x2 = np.array(x2l).astype(dtype)
+ self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
+ self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
+
+ def testPowNegativeExponent(self):
+ for dtype in [np.int32, np.int64]:
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([-2, 3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = np.array([2, -3]).astype(dtype)
+ sess.run(math_ops.pow(x, y))
+
+ with self.test_session(use_gpu=False) as sess:
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ "Integers to negative integer powers are not allowed"):
+ x = np.array([5, 2]).astype(dtype)
+ y = -3
+ sess.run(math_ops.pow(x, y))
+
+
+class ComparisonOpTest(test.TestCase):
+
+ def _compareScalar(self, func, x, y, dtype):
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ out = func(
+ ops.convert_to_tensor(np.array([x]).astype(dtype)),
+ ops.convert_to_tensor(np.array([y]).astype(dtype)))
+ ret = out.eval()
+ return ret[0]
+
+ def testScalarCompareScalar(self):
+ dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+ data = [-1, 0, 1]
+ for t in dtypes:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compareScalar(math_ops.less, x, y, t), x < y)
+ self.assertEqual(
+ self._compareScalar(math_ops.less_equal, x, y, t), x <= y)
+ self.assertEqual(
+ self._compareScalar(math_ops.greater, x, y, t), x > y)
+ self.assertEqual(
+ self._compareScalar(math_ops.greater_equal, x, y, t), x >= y)
+ self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+ self.assertEqual(
+ self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+ data = [-1, 0, 1, -1j, 1j, 1 + 1j, 1 - 1j]
+ for t in [np.complex64, np.complex128]:
+ for x in data:
+ for y in data:
+ self.assertEqual(self._compareScalar(math_ops.equal, x, y, t), x == y)
+ self.assertEqual(
+ self._compareScalar(math_ops.not_equal, x, y, t), x != y)
+
+ def _compare(self, x, y, np_func, tf_func):
+ np_ans = np_func(x, y)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y))
+ tf_ans = out.eval()
+ self.assertAllEqual(np_ans, tf_ans)
+
+ def testTensorCompareTensor(self):
+ x = np.linspace(-15, 15, 6).reshape(1, 3, 2)
+ y = np.linspace(20, -10, 6).reshape(1, 3, 2)
+ for t in [np.float16, np.float32, np.float64, np.int32, np.int64]:
+ xt = x.astype(t)
+ yt = y.astype(t)
+ self._compare(xt, yt, np.less, math_ops.less)
+ self._compare(xt, yt, np.less_equal, math_ops.less_equal)
+ self._compare(xt, yt, np.greater, math_ops.greater)
+ self._compare(xt, yt, np.greater_equal, math_ops.greater_equal)
+ self._compare(xt, yt, np.equal, math_ops.equal)
+ self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+ # Complex types do not support ordering but do support equality tests.
+ for t in [np.complex64, np.complex128]:
+ xt = x.astype(t)
+ xt -= 1j * xt
+ yt = y.astype(t)
+ yt -= 1j * yt
+ self._compare(xt, yt, np.equal, math_ops.equal)
+ self._compare(xt, yt, np.not_equal, math_ops.not_equal)
+
+ def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
+ x = np.linspace(-15, 15, np.prod(xs)).astype(dtype).reshape(xs)
+ y = np.linspace(20, -10, np.prod(ys)).astype(dtype).reshape(ys)
+ if dtype in (np.complex64, np.complex128):
+ x -= 1j * x
+ y -= 1j * y
+ self._compare(x, y, np_func, tf_func)
+ self._compare(y, x, np_func, tf_func)
+
+ def _testBCastByFunc(self, np_func, tf_func, include_complex=False):
+ shapes = [
+ ([1, 3, 2], [1]),
+ ([1, 3, 2], [2]),
+ ([1, 3, 2], [3, 2]),
+ ([1, 3, 2], [3, 1]),
+ ([1, 3, 2], [1, 3, 2]),
+ ([1, 3, 2], [2, 3, 1]),
+ ([1, 3, 2], [2, 1, 1]),
+ ([1, 3, 2], [1, 3, 1]),
+ ([2, 1, 5], [2, 3, 1]),
+ ([2, 0, 5], [2, 0, 1]),
+ ([2, 3, 0], [2, 3, 1]),
+ ]
+ dtypes = [
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int32,
+ np.int64,
+ ]
+ if include_complex:
+ dtypes.extend([np.complex64, np.complex128])
+
+ for (xs, ys) in shapes:
+ for dtype in dtypes:
+ self._compareBCast(xs, ys, dtype, np_func, tf_func)
+
+ def testBCastLess(self):
+ self._testBCastByFunc(np.less, math_ops.less)
+
+ def testBCastLessEqual(self):
+ self._testBCastByFunc(np.less_equal, math_ops.less_equal)
+
+ def testBCastGreater(self):
+ self._testBCastByFunc(np.greater, math_ops.greater)
+
+ def testBCastGreaterEqual(self):
+ self._testBCastByFunc(np.greater_equal, math_ops.greater_equal)
+
+ def testBCastEqual(self):
+ self._testBCastByFunc(np.equal, math_ops.equal, include_complex=True)
+
+ def testBCastNotEqual(self):
+ self._testBCastByFunc(
+ np.not_equal, math_ops.not_equal, include_complex=True)
+
+ def testShapeMismatch(self):
+ dtypes = [np.float16, np.float32, np.float64, np.int32, np.int64]
+ funcs = [
+ math_ops.less, math_ops.less_equal, math_ops.greater,
+ math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+ ]
+ x = np.arange(0, 10).reshape([2, 5])
+ y = np.arange(0, 10).reshape([5, 2])
+ for t in dtypes:
+ for f in funcs:
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: "Dimensions must" in str(e)):
+ f(x.astype(t), y.astype(t))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index b61232cded..c5311ad834 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -18,25 +18,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import math
-
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
-from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
_ADD = lambda x, y: x + y
_SUB = lambda x, y: x - y
@@ -45,8 +39,6 @@ _POW = lambda x, y: x**y
_TRUEDIV = lambda x, y: x / y
_FLOORDIV = lambda x, y: x // y
_MOD = lambda x, y: x % y
-_NEG = lambda x: -x
-_ABS = abs
_LT = lambda x, y: x < y
_LE = lambda x, y: x <= y
@@ -74,8 +66,11 @@ def _sparsify(x, thresh=0.5, index_dtype=np.int64):
def _default_tolerance(dtype):
- """Returns a sensible default tolerance for comparing results of a given
- type"""
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
if dtype == np.float16:
return 5e-3
elif dtype in (np.float32, np.complex64):
@@ -86,1147 +81,6 @@ def _default_tolerance(dtype):
return None # Fail fast for unexpected types
-class UnaryOpTest(test.TestCase):
-
- def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
- if grad_rtol is None:
- grad_rtol = _default_tolerance(x.dtype)
- if grad_atol is None:
- grad_atol = _default_tolerance(x.dtype)
- np_ans = np_func(x)
- with self.test_session(use_gpu=False):
- inx = ops.convert_to_tensor(x)
- if x.dtype in (np.float32, np.float64,
- dtypes_lib.bfloat16.as_numpy_dtype):
- y = 1.1 * tf_func(inx)
- np_ans *= 1.1
- else:
- y = tf_func(inx)
- tf_cpu = y.eval()
- self.assertShapeEqual(np_ans, y)
- if x.dtype == np.float16:
- self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
- elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
- self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
- else:
- self.assertAllClose(np_ans, tf_cpu)
-
- if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
- return # Return early
-
- if x.dtype == np.float16:
- s = list(np.shape(x))
- jacob_t, _ = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x)
- xf = x.astype(np.float)
- inxf = ops.convert_to_tensor(xf)
- yf = tf_func(inxf)
- _, jacob_n = gradient_checker.compute_gradient(
- inxf, s, yf, s, x_init_value=xf, delta=1e-2)
- jacob_n = jacob_n.astype(np.float16)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
- elif x.dtype in (np.float32, np.complex64):
- s = list(np.shape(x))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x, delta=1e-3)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
- elif x.dtype in (np.float64, np.complex128):
- s = list(np.shape(x))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, s, y, s, x_init_value=x, delta=1e-5)
- self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
-
- def _check(self, result_tensor, result_np, input_sp_t, tol):
- self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
- self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
- self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
- self.assertAllEqual(input_sp_t.dense_shape.eval(),
- result_tensor.dense_shape.eval())
- if tol is None:
- self.assertAllClose(result_np, result_tensor.values.eval())
- else:
- self.assertAllClose(
- result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
-
- def _compareSparseCpu(self, x, np_func, tf_func, tol):
- x_sp, x_sp_vals = _sparsify(x)
- res_np = np_func(x_sp_vals)
- with self.test_session(use_gpu=False):
- self._check(tf_func(x_sp), res_np, x_sp, tol)
-
- def _compareGpu(self, x, np_func, tf_func):
- np_ans = np_func(x)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- result = tf_func(ops.convert_to_tensor(x))
- tf_gpu = result.eval()
- if x.dtype == np.float16:
- self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
- else:
- self.assertAllClose(np_ans, tf_gpu)
- # TODO(zhifengc/ke): make gradient checker work on GPU.
-
- def _compareSparseGpu(self, x, np_func, tf_func, tol):
- x_sp, x_sp_vals = _sparsify(x)
- res_np = np_func(x_sp_vals)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- self._check(tf_func(x_sp), res_np, x_sp, tol)
-
- def _compareBoth(self, x, np_func, tf_func):
- self._compareCpu(x, np_func, tf_func)
- self._compareGpu(x, np_func, tf_func)
-
- def _compareBothSparse(self, x, np_func, tf_func, tol=None):
- self._compareSparseCpu(x, np_func, tf_func, tol)
- self._compareSparseGpu(x, np_func, tf_func, tol)
-
- def _inv(self, x):
- return 1.0 / x
-
- def _rsqrt(self, x):
- return self._inv(np.sqrt(x))
-
- def _sigmoid(self, x):
- return 1.0 / (1.0 + np.exp(-x))
-
- def _log_sigmoid(self, x):
- return np.log(self._sigmoid(x))
-
- def _replace_domain_error_with_inf(self, fn):
-
- def func(x):
- try:
- return fn(x)
- except ValueError as e:
- if "domain error" in str(e):
- return np.inf * np.ones_like(x)
- else:
- raise e
-
- return func
-
- def testFloatBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
- w = x - x.min() + 1.02 # all greater than 1
- y = (x + .5).astype(np.float32) # no zero
- z = (x + 15.5).astype(np.float32) # all positive
- k = np.arange(-0.90, 0.90, 0.25).astype(np.float32) # between -1 and 1
-
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(w, np.arccosh, math_ops.acosh)
- self._compareBoth(k, np.arctanh, math_ops.atanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(k, np.arcsin, math_ops.asin)
- self._compareBoth(k, np.arccos, math_ops.acos)
- self._compareBoth(x, np.arctan, math_ops.atan)
- self._compareBoth(x, np.tan, math_ops.tan)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
- def testFloatTanhEdge(self):
- x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
- self._compareBoth(x, np.tanh, math_ops.tanh)
-
- def testFloatEmpty(self):
- x = np.empty((2, 0, 5), dtype=np.float32)
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(x, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(x, np.sqrt, math_ops.sqrt)
- self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(x, np.log, math_ops.log)
- self._compareBoth(x, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(x, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- # Can't use vectorize below, so just use some arbitrary function
- self._compareBoth(x, np.sign, math_ops.lgamma)
- self._compareBoth(x, np.sign, math_ops.erf)
- self._compareBoth(x, np.sign, math_ops.erfc)
- self._compareBoth(x, np.tan, math_ops.tan)
- self._compareBoth(x, np.arcsin, math_ops.asin)
- self._compareBoth(x, np.arccos, math_ops.acos)
- self._compareBoth(x, np.arctan, math_ops.atan)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(x, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.sign, math_ops.erf)
-
- def testDoubleBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
- w = x - x.min() + 1.02 # all greater than 1
- y = (x + .5).astype(np.float64) # no zero
- z = (x + 15.5).astype(np.float64) # all positive
- k = np.arange(-0.90, 0.90,
- 0.35).reshape(1, 3, 2).astype(np.float64) # between -1 and 1
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.sinh, math_ops.sinh)
- self._compareBoth(x, np.cosh, math_ops.cosh)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, np.arcsinh, math_ops.asinh)
- self._compareBoth(w, np.arccosh, math_ops.acosh)
- self._compareBoth(k, np.arctanh, math_ops.atanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- self._compareBoth(x, np.arctan, math_ops.atan)
- self._compareBoth(k, np.arcsin, math_ops.asin)
- self._compareBoth(k, np.arccos, math_ops.acos)
- self._compareBoth(k, np.tan, math_ops.tan)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
-
- def testHalfBasic(self):
- x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
- y = (x + .5).astype(np.float16) # no zero
- z = (x + 15.5).astype(np.float16) # all positive
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(y, self._inv, math_ops.reciprocal)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareBoth(z, np.sqrt, math_ops.sqrt)
- self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareBoth(x, np.expm1, math_ops.expm1)
- self._compareBoth(z, np.log, math_ops.log)
- self._compareBoth(z, np.log1p, math_ops.log1p)
- self._compareBoth(x, np.tanh, math_ops.tanh)
- self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
- self._compareBoth(y, np.sign, math_ops.sign)
- self._compareBoth(x, np.sin, math_ops.sin)
- self._compareBoth(x, np.cos, math_ops.cos)
- self._compareBoth(y,
- np.vectorize(
- self._replace_domain_error_with_inf(math.lgamma)),
- math_ops.lgamma)
- self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
- self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
- self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
- self._compareBothSparse(y, np.sign, math_ops.sign)
- self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
-
- def testInt32Basic(self):
- x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
- self._compareCpu(x, np.abs, math_ops.abs)
- self._compareCpu(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareBoth(x, np.square, math_ops.square)
- self._compareCpu(x, np.sign, math_ops.sign)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sign, math_ops.sign)
-
- def testInt64Basic(self):
- x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
- self._compareCpu(x, np.abs, math_ops.abs)
- self._compareCpu(x, np.abs, _ABS)
- self._compareCpu(x, np.negative, math_ops.negative)
- self._compareCpu(x, np.negative, _NEG)
- self._compareCpu(x, np.sign, math_ops.sign)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.sign, math_ops.sign)
-
- def testInt64Square(self):
- x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.square, math_ops.square)
-
- def testComplex64Basic(self):
- x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
- np.complex64)
- y = x + np.complex(0.5, 0.5) # no zeros
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, math_ops.reciprocal)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareCpu(y, np.sqrt, math_ops.sqrt)
- self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareCpu(x, np.expm1, math_ops.expm1)
- self._compareCpu(y, np.log, math_ops.log)
- self._compareCpu(y, np.log1p, math_ops.log1p)
- self._compareCpu(x, np.sinh, math_ops.sinh)
- self._compareCpu(x, np.cosh, math_ops.cosh)
- self._compareCpu(x, np.tanh, math_ops.tanh)
-
- # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
- # of precision.
- # Small gradient values + low precision --> High relative error
- self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
- self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
-
- self._compareCpu(y, np.arctanh, math_ops.atanh)
- self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
- self._compareCpu(x, np.sin, math_ops.sin)
- self._compareCpu(x, np.cos, math_ops.cos)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
- # Numpy uses an incorrect definition of sign; use the right one instead.
- def complex_sign(x):
- return x / np.abs(x)
-
- self._compareBoth(y, complex_sign, math_ops.sign)
- self._compareBothSparse(y, complex_sign, math_ops.sign)
-
- def testComplex128Basic(self):
- x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
- np.complex128)
- y = x + np.complex(0.5, 0.5) # no zeros
- self._compareBoth(x, np.abs, math_ops.abs)
- self._compareBoth(x, np.abs, _ABS)
- self._compareBoth(x, np.negative, math_ops.negative)
- self._compareBoth(x, np.negative, _NEG)
- self._compareCpu(y, self._inv, math_ops.reciprocal)
- self._compareCpu(x, np.square, math_ops.square)
- self._compareCpu(y, np.sqrt, math_ops.sqrt)
- self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
- self._compareBoth(x, np.exp, math_ops.exp)
- self._compareCpu(x, np.expm1, math_ops.expm1)
- self._compareCpu(y, np.log, math_ops.log)
- self._compareCpu(y, np.log1p, math_ops.log1p)
- self._compareCpu(x, np.sinh, math_ops.sinh)
- self._compareCpu(x, np.cosh, math_ops.cosh)
- self._compareCpu(x, np.tanh, math_ops.tanh)
- self._compareCpu(y, np.arcsinh, math_ops.asinh)
- self._compareCpu(y, np.arccosh, math_ops.acosh)
- self._compareCpu(y, np.arctanh, math_ops.atanh)
- self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
- self._compareCpu(x, np.sin, math_ops.sin)
- self._compareCpu(x, np.cos, math_ops.cos)
-
- self._compareBothSparse(x, np.abs, math_ops.abs)
- self._compareBothSparse(x, np.negative, math_ops.negative)
- self._compareBothSparse(x, np.square, math_ops.square)
- self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
- self._compareBothSparse(x, np.tanh, math_ops.tanh)
-
- # Numpy uses an incorrect definition of sign; use the right one instead.
- def complex_sign(x):
- return x / np.abs(x)
-
- self._compareBoth(y, complex_sign, math_ops.sign)
- self._compareBothSparse(y, complex_sign, math_ops.sign)
-
- def testGradGrad(self):
- np.random.seed(7)
- shape = (5,)
- dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
- (np.complex128, 1e-6)]
- op_range = [
- (gen_math_ops.reciprocal_grad, [-2, 2]),
- (gen_math_ops.rsqrt_grad, [0.1, 3]),
- (gen_math_ops.sigmoid_grad, [-2, 2]),
- (gen_math_ops.sqrt_grad, [0.1, 3]),
- (gen_math_ops.tanh_grad, [-2, 2]),
- ]
-
- def rand(dtype):
- x = np.random.uniform(
- real_range[0], real_range[1], size=shape[0]).astype(dtype)
- if dtype in (np.complex64, np.complex128):
- x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
- return x
-
- for op, real_range in op_range:
- with self.test_session():
- for dtype, tol in dtype_tols:
- x = constant_op.constant(rand(dtype))
- y = constant_op.constant(rand(dtype))
- z = op(x, y)
- grads = gradient_checker.compute_gradient(
- [x, y], [shape, shape],
- z,
- shape,
- x_init_value=[rand(dtype), rand(dtype)])
- if isinstance(grads, tuple):
- grads = [grads]
- for analytical, numerical in grads:
- self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
-
-
-class BinaryOpTest(test.TestCase):
-
- def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
- np_ans = np_func(x, y)
- with self.test_session(use_gpu=False):
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- out = tf_func(inx, iny)
- tf_cpu = out.eval()
- # Test that the op takes precedence over numpy operators.
- np_left = tf_func(x, iny).eval()
- np_right = tf_func(inx, y).eval()
-
- if also_compare_variables:
- var_x = variables.Variable(x)
- var_y = variables.Variable(y)
- variables.global_variables_initializer().run()
- print(type(x), type(y), type(var_x), type(var_y))
- print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
- np_var_left = tf_func(x, var_y).eval()
- np_var_right = tf_func(var_x, y).eval()
-
- if np_ans.dtype != np.object:
- self.assertAllClose(np_ans, tf_cpu)
- self.assertAllClose(np_ans, np_left)
- self.assertAllClose(np_ans, np_right)
- if also_compare_variables:
- self.assertAllClose(np_ans, np_var_left)
- self.assertAllClose(np_ans, np_var_right)
- self.assertShapeEqual(np_ans, out)
-
- _GRAD_TOL = {
- dtypes_lib.float16: 1e-3,
- dtypes_lib.float32: 1e-3,
- dtypes_lib.complex64: 1e-2,
- dtypes_lib.float64: 1e-5,
- dtypes_lib.complex128: 1e-4
- }
-
- def _compareGradientX(self,
- x,
- y,
- np_func,
- tf_func,
- numeric_gradient_type=None):
- z = np_func(x, y)
- zs = list(z.shape)
- with self.test_session():
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- if x.dtype in (np.float32, np.float64):
- out = 1.1 * tf_func(inx, iny)
- else:
- out = tf_func(inx, iny)
- xs = list(x.shape)
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- inx, xs, out, zs, x_init_value=x)
- if numeric_gradient_type is not None:
- xf = x.astype(numeric_gradient_type)
- yf = y.astype(numeric_gradient_type)
- inxf = ops.convert_to_tensor(xf)
- inyf = ops.convert_to_tensor(yf)
- outf = tf_func(inxf, inyf)
- _, jacob_n = gradient_checker.compute_gradient(
- inxf, xs, outf, zs, x_init_value=xf, delta=1e-3)
- jacob_n = jacob_n.astype(x.dtype)
- tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
- self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
- def _compareGradientY(self,
- x,
- y,
- np_func,
- tf_func,
- numeric_gradient_type=None):
- z = np_func(x, y)
- zs = list(z.shape)
- with self.test_session():
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- if x.dtype in (np.float32, np.float64):
- out = 1.1 * tf_func(inx, iny)
- else:
- out = tf_func(inx, iny)
- ys = list(np.shape(y))
- jacob_t, jacob_n = gradient_checker.compute_gradient(
- iny, ys, out, zs, x_init_value=y)
- if numeric_gradient_type is not None:
- xf = x.astype(numeric_gradient_type)
- yf = y.astype(numeric_gradient_type)
- inxf = ops.convert_to_tensor(xf)
- inyf = ops.convert_to_tensor(yf)
- outf = tf_func(inxf, inyf)
- _, jacob_n = gradient_checker.compute_gradient(
- inyf, ys, outf, zs, x_init_value=yf)
- jacob_n = jacob_n.astype(x.dtype)
- tol = self._GRAD_TOL[dtypes_lib.as_dtype(x.dtype)]
- self.assertAllClose(jacob_t, jacob_n, rtol=tol, atol=tol)
-
- def _compareGpu(self, x, y, np_func, tf_func):
- np_ans = np_func(x, y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- out = tf_func(inx, iny)
- tf_gpu = out.eval()
- self.assertAllClose(np_ans, tf_gpu)
- self.assertShapeEqual(np_ans, out)
- # TODO(zhifengc/ke): make gradient checker work on GPU.
-
- def _compareBoth(self, x, y, np_func, tf_func, also_compare_variables=False):
- self._compareCpu(x, y, np_func, tf_func, also_compare_variables)
- if x.dtype in (np.float16, np.float32, np.float64, np.complex64,
- np.complex128):
- if tf_func not in (_FLOORDIV, math_ops.floordiv, math_ops.zeta,
- math_ops.polygamma):
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
- if tf_func in (math_ops.zeta, math_ops.polygamma):
- # These methods only support gradients in the second parameter
- self._compareGradientY(x, y, np_func, tf_func)
- self._compareGpu(x, y, np_func, tf_func)
-
- def testFloatBasic(self):
- x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float32)
- y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(x, y, np.add, math_ops.add, also_compare_variables=True)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.arctan2, math_ops.atan2)
- x1 = np.random.randn(5, 6).astype(np.float32)
- x2 = np.random.randn(5, 6).astype(np.float32)
- # Remove tiny values--atan2 gradients are flaky near the origin.
- x1[np.abs(x1) < 0.05] = 0.05 * np.sign(x1[np.abs(x1) < 0.05])
- x2[np.abs(x2) < 0.05] = 0.05 * np.sign(x2[np.abs(x2) < 0.05])
- self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
- x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
- math_ops.igamma)
- self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
- math_ops.igammac)
- # Need x > 1
- self._compareBoth(x_pos_small + 1, a_pos_small, special.zeta,
- math_ops.zeta)
- n_small = np.arange(0, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(n_small, x_pos_small, special.polygamma,
- math_ops.polygamma)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- def testFloatDifferentShapes(self):
- x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.float32)
- y = np.array([1, 2]).reshape(2, 1).astype(np.float32)
- with self.test_session() as sess:
- inx = ops.convert_to_tensor(x)
- iny = ops.convert_to_tensor(y)
- s = math_ops.reduce_sum(inx * iny)
- gx, gy = sess.run(gradients_impl.gradients(s, [inx, iny]))
- # gx is simply the broadcasted y
- self.assertAllEqual(gx,
- np.array([1, 1, 2, 2]).reshape(2, 2).astype(np.float32))
- # gy is x's column summed up
- self.assertAllEqual(gy, np.array([3, 7]).reshape(2, 1).astype(np.float32))
-
- def testFloatVariableOverload(self):
- x = np.array([1, 2, 3, 4]).reshape(2, 2).astype(np.int32)
- y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
- var_x = variables.Variable(x)
- var_y = variables.Variable(y)
- with self.test_session() as sess:
- sess.run([var_x.initializer, var_y.initializer])
- left_result = (var_x * y).eval()
- right_result = (x * var_y).eval()
- np_result = x * y
- self.assertAllEqual(np_result, left_result)
- self.assertAllEqual(np_result, right_result)
-
- def testDoubleBasic(self):
- x = np.linspace(-5, 20, 15).reshape(1, 3, 5).astype(np.float64)
- y = np.linspace(20, -5, 15).reshape(1, 3, 5).astype(np.float64)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y + 0.1, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y + 0.1, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.arctan2, math_ops.atan2)
- x1 = np.random.randn(7, 4).astype(np.float64)
- x2 = np.random.randn(7, 4).astype(np.float64)
- # Remove tiny values--atan2 gradients are flaky near the origin.
- x1[np.abs(x1) < 0.5] = 0.5 * np.sign(x1[np.abs(x1) < 0.5])
- x2[np.abs(x2) < 0.5] = 0.5 * np.sign(x2[np.abs(x2) < 0.5])
- self._compareBoth(x1, x2, np.arctan2, math_ops.atan2)
- try:
- from scipy import special # pylint: disable=g-import-not-at-top
- a_pos_small = np.linspace(0.1, 2, 15).reshape(1, 3, 5).astype(np.float32)
- x_pos_small = np.linspace(0.1, 10, 15).reshape(1, 3, 5).astype(np.float32)
- self._compareBoth(a_pos_small, x_pos_small, special.gammainc,
- math_ops.igamma)
- self._compareBoth(a_pos_small, x_pos_small, special.gammaincc,
- math_ops.igammac)
- except ImportError as e:
- tf_logging.warn("Cannot test special functions: %s" % str(e))
-
- def testUint8Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint8)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint8)
- self._compareBoth(x, y, np.add, math_ops.add)
-
- def testInt8Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int8)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int8)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
-
- def testInt16Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int16)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int16)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
-
- def testUint16Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.uint16)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.uint16)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
-
- def testInt32Basic(self):
- x = np.arange(1, 13, 2).reshape(1, 3, 2).astype(np.int32)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int32)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.mod, math_ops.mod)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.mod, _MOD)
- # _compareBoth tests on GPU only for floating point types, so test
- # _MOD for int32 on GPU by calling _compareGpu
- self._compareGpu(x, y, np.mod, _MOD)
-
- def testInt64Basic(self):
- x = np.arange(1 << 40, 13 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
- y = np.arange(1, 7, 1).reshape(1, 3, 2).astype(np.int64)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.floor_divide, math_ops.floordiv)
- self._compareBoth(x, y, np.mod, math_ops.mod)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y, np.true_divide, _TRUEDIV)
- self._compareBoth(x, y, np.floor_divide, _FLOORDIV)
- self._compareBoth(x, y, np.mod, _MOD)
-
- def testComplex64Basic(self):
- x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
- np.complex64)
- y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
- np.complex64)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
- def testComplex128Basic(self):
- x = np.complex(1, 1) * np.linspace(-10, 10, 6).reshape(1, 3, 2).astype(
- np.complex128)
- y = np.complex(1, 1) * np.linspace(20, -20, 6).reshape(1, 3, 2).astype(
- np.complex128)
- self._compareBoth(x, y, np.add, math_ops.add)
- self._compareBoth(x, y, np.subtract, math_ops.subtract)
- self._compareBoth(x, y, np.multiply, math_ops.multiply)
- self._compareBoth(x, y + 0.1, np.true_divide, math_ops.truediv)
- self._compareBoth(x, y, np.add, _ADD)
- self._compareBoth(x, y, np.subtract, _SUB)
- self._compareBoth(x, y, np.multiply, _MUL)
- self._compareBoth(x, y + 0.1, np.true_divide, _TRUEDIV)
-
- def testStringComparison(self):
- x = np.array([["abc", "bh"], ["c", ""]])
- y = np.array([["abc", "bh"], ["def", "hi"]])
- with self.test_session(use_gpu=False) as sess:
- cmp_eq = math_ops.equal(x, y)
- cmp_not_eq = math_ops.not_equal(x, y)
- values = sess.run([cmp_eq, cmp_not_eq])
- self.assertAllEqual([[True, True], [False, False]], values[0])
- self.assertAllEqual([[False, False], [True, True]], values[1])
-
- def testString(self):
- x = np.array(
- [["x_0_0", "x_0_1", "x_0_2"], ["x_1_0", "x_1_1", "x_1_2"],
- ["x_2_0", "x_2_1", "x_2_2"]],
- dtype=np.object)
- y = np.array(
- [["y_0_0", "y_0_1", "y_0_2"], ["y_1_0", "y_1_1", "y_1_2"],
- ["y_2_0", "y_2_1", "y_2_2"]],
- dtype=np.object)
- z = np.array([["z_0", "z_1", "z_2"]], dtype=np.object)
- w = np.array("w", dtype=np.object)
- self._compareCpu(x, y, _ADD, _ADD)
- self._compareCpu(x, z, _ADD, _ADD)
- self._compareCpu(x, w, _ADD, _ADD)
- self._compareCpu(z, w, _ADD, _ADD)
-
- def _compareBCast(self, xs, ys, dtype, np_func, tf_func):
- if dtype in (np.complex64, np.complex128):
- x = (1 + np.linspace(0, 2 + 3j, np.prod(xs))).astype(dtype).reshape(xs)
- y = (1 + np.linspace(0, 2 - 2j, np.prod(ys))).astype(dtype).reshape(ys)
- else:
- x = (1 + np.linspace(0, 5, np.prod(xs))).astype(dtype).reshape(xs)
- y = (1 + np.linspace(0, 5, np.prod(ys))).astype(dtype).reshape(ys)
- self._compareCpu(x, y, np_func, tf_func)
- if x.dtype in (np.float16, np.float32, np.float64):
- # TODO(aselle): Make the test work for dtypes:
- # (np.complex64, np.complex128).
- if tf_func not in (_FLOORDIV, math_ops.floordiv):
- if x.dtype == np.float16:
- # Compare fp16 theoretical gradients to fp32 numerical gradients,
- # since fp16 numerical gradients are too imprecise unless great
- # care is taken with choosing the inputs and the delta. This is
- # a weaker check (in particular, it does not test the op itself,
- # only its gradient), but it's much better than nothing.
- self._compareGradientX(x, y, np_func, tf_func, np.float)
- self._compareGradientY(x, y, np_func, tf_func, np.float)
- else:
- self._compareGradientX(x, y, np_func, tf_func)
- self._compareGradientY(x, y, np_func, tf_func)
- self._compareGpu(x, y, np_func, tf_func)
-
- # TODO(josh11b,vrv): Refactor this to use parameterized tests.
- def _testBCastByFunc(self, funcs, xs, ys):
- dtypes = [
- np.float16,
- np.float32,
- np.float64,
- np.int32,
- np.int64,
- np.complex64,
- np.complex128,
- ]
- for dtype in dtypes:
- for (np_func, tf_func) in funcs:
- if (dtype in (np.complex64, np.complex128) and
- tf_func in (_FLOORDIV, math_ops.floordiv)):
- continue # floordiv makes no sense for complex numbers
- self._compareBCast(xs, ys, dtype, np_func, tf_func)
- self._compareBCast(ys, xs, dtype, np_func, tf_func)
-
- def _testBCastA(self, xs, ys):
- funcs = [
- (np.add, math_ops.add),
- (np.add, _ADD),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastB(self, xs, ys):
- funcs = [
- (np.subtract, math_ops.subtract),
- (np.subtract, _SUB),
- (np.power, math_ops.pow),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastC(self, xs, ys):
- funcs = [
- (np.multiply, math_ops.multiply),
- (np.multiply, _MUL),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def _testBCastD(self, xs, ys):
- funcs = [
- (np.true_divide, math_ops.truediv),
- (np.floor_divide, math_ops.floordiv),
- (np.true_divide, _TRUEDIV),
- (np.floor_divide, _FLOORDIV),
- ]
- self._testBCastByFunc(funcs, xs, ys)
-
- def testBCast_0A(self):
- self._testBCastA([1, 3, 2], [1])
-
- def testBCast_0B(self):
- self._testBCastB([1, 3, 2], [1])
-
- def testBCast_0C(self):
- self._testBCastC([1, 3, 2], [1])
-
- def testBCast_0D(self):
- self._testBCastD([1, 3, 2], [1])
-
- def testBCast_1A(self):
- self._testBCastA([1, 3, 2], [2])
-
- def testBCast_1B(self):
- self._testBCastB([1, 3, 2], [2])
-
- def testBCast_1C(self):
- self._testBCastC([1, 3, 2], [2])
-
- def testBCast_1D(self):
- self._testBCastD([1, 3, 2], [2])
-
- def testBCast_2A(self):
- self._testBCastA([1, 3, 2], [3, 2])
-
- def testBCast_2B(self):
- self._testBCastB([1, 3, 2], [3, 2])
-
- def testBCast_2C(self):
- self._testBCastC([1, 3, 2], [3, 2])
-
- def testBCast_2D(self):
- self._testBCastD([1, 3, 2], [3, 2])
-
- def testBCast_3A(self):
- self._testBCastA([1, 3, 2], [3, 1])
-
- def testBCast_3B(self):
- self._testBCastB([1, 3, 2], [3, 1])
-
- def testBCast_3C(self):
- self._testBCastC([1, 3, 2], [3, 1])
-
- def testBCast_3D(self):
- self._testBCastD([1, 3, 2], [3, 1])
-
- def testBCast_4A(self):
- self._testBCastA([1, 3, 2], [1, 3, 2])
-
- def testBCast_4B(self):
- self._testBCastB([1, 3, 2], [1, 3, 2])
-
- def testBCast_4C(self):
- self._testBCastC([1, 3, 2], [1, 3, 2])
-
- def testBCast_4D(self):
- self._testBCastD([1, 3, 2], [1, 3, 2])
-
- def testBCast_5A(self):
- self._testBCastA([1, 3, 2], [2, 3, 1])
-
- def testBCast_5B(self):
- self._testBCastB([1, 3, 2], [2, 3, 1])
-
- def testBCast_5C(self):
- self._testBCastC([1, 3, 2], [2, 3, 1])
-
- def testBCast_5D(self):
- self._testBCastD([1, 3, 2], [2, 3, 1])
-
- def testBCast_6A(self):
- self._testBCastA([1, 3, 2], [2, 1, 1])
-
- def testBCast_6B(self):
- self._testBCastB([1, 3, 2], [2, 1, 1])
-
- def testBCast_6C(self):
- self._testBCastC([1, 3, 2], [2, 1, 1])
-
- def testBCast_6D(self):
- self._testBCastD([1, 3, 2], [2, 1, 1])
-
- def testBCast_7A(self):
- self._testBCastA([1, 3, 2], [1, 3, 1])
-
- def testBCast_7B(self):
- self._testBCastB([1, 3, 2], [1, 3, 1])
-
- def testBCast_7C(self):
- self._testBCastC([1, 3, 2], [1, 3, 1])
-
- def testBCast_7D(self):
- self._testBCastD([1, 3, 2], [1, 3, 1])
-
- def testBCast_8A(self):
- self._testBCastA([2, 1, 5], [2, 3, 1])
-
- def testBCast_8B(self):
- self._testBCastB([2, 1, 5], [2, 3, 1])
-
- def testBCast_8C(self):
- self._testBCastC([2, 1, 5], [2, 3, 1])
-
- def testBCast_8D(self):
- self._testBCastD([2, 1, 5], [2, 3, 1])
-
- def testBCast_9A(self):
- self._testBCastA([2, 0, 5], [2, 0, 1])
-
- def testBCast_9B(self):
- self._testBCastB([2, 0, 5], [2, 0, 1])
-
- def testBCast_9C(self):
- self._testBCastC([2, 0, 5], [2, 0, 1])
-
- def testBCast_9D(self):
- self._testBCastD([2, 0, 5], [2, 0, 1])
-
- def testBCast_10A(self):
- self._testBCastA([2, 3, 0], [2, 3, 1])
-
- def testBCast_10B(self):
- self._testBCastB([2, 3, 0], [2, 3, 1])
-
- def testBCast_10C(self):
- self._testBCastC([2, 3, 0], [2, 3, 1])
-
- def testBCast_10D(self):
- self._testBCastD([2, 3, 0], [2, 3, 1])
-
- def testBCast_11A(self):
- self._testBCastA([1, 3, 2], [1, 3, 2])
-
- def testBCast_11B(self):
- self._testBCastB([1, 3, 2], [1, 3, 2])
-
- def testBCast_11C(self):
- self._testBCastC([1, 3, 2], [1, 3, 2])
-
- def testBCast_11D(self):
- self._testBCastD([1, 3, 2], [1, 3, 2])
-
- def testBCast_12A(self):
- self._testBCastA([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12B(self):
- self._testBCastB([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12C(self):
- self._testBCastC([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_12D(self):
- self._testBCastD([1, 1, 1, 1, 3, 2], [1, 3, 2])
-
- def testBCast_13A(self):
- self._testBCastA([1, 3, 2, 1, 1], [1])
-
- def testBCast_13B(self):
- self._testBCastB([1, 3, 2, 1, 1], [1])
-
- def testBCast_13C(self):
- self._testBCastC([1, 3, 2, 1, 1], [1])
-
- def testBCast_13D(self):
- self._testBCastD([1, 3, 2, 1, 1], [1])
-
- def testBCast_14A(self):
- self._testBCastA([2, 3, 1, 1, 5], [1])
-
- def testBCast_14B(self):
- self._testBCastB([2, 3, 1, 1, 5], [1])
-
- def testBCast_14C(self):
- self._testBCastC([2, 3, 1, 1, 5], [1])
-
- def testBCast_14D(self):
- self._testBCastD([2, 3, 1, 1, 5], [1])
-
- def testBCast_15A(self):
- self._testBCastA([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15B(self):
- self._testBCastB([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15C(self):
- self._testBCastC([10, 3, 1, 2], [3, 1, 2])
-
- def testBCast_15D(self):
- self._testBCastD([10, 3, 1, 2], [3, 1, 2])
-
- def testMismatchedDimensions(self):
- for func in [
- math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.div, _ADD,
- _SUB, _MUL, _TRUEDIV, _FLOORDIV
- ]:
- with self.assertRaisesWithPredicateMatch(
- ValueError, lambda e: "Dimensions must" in str(e)):
- func(
- ops.convert_to_tensor([10.0, 20.0, 30.0]),
- ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0]]))
-
- def testZeroPowGrad(self):
- with self.test_session():
- for dtype in (np.float16, np.float32, np.float64, np.complex64,
- np.complex128):
- x = constant_op.constant(0.0, dtype=dtype)
- y = constant_op.constant(2.0, dtype=dtype)
- z = math_ops.pow(x, y)
- error = gradient_checker.compute_gradient_error(y, [], z, [])
- self.assertEqual(error, 0)
-
- def testComplexPowGrad(self):
- with self.test_session():
- for dtype in np.complex64, np.complex128:
- for base in 2.0, -2.0:
- x = constant_op.constant(base, dtype=dtype)
- y = constant_op.constant(2.0, dtype=dtype)
- z = math_ops.pow(x, y)
- error = gradient_checker.compute_gradient_error(y, [], z, [])
- self.assertLess(error, 2e-4)
-
- def testAtan2SpecialValues(self):
- x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
- (1.2345, float("inf")), (1.2345, -float("inf")),
- (-4.321, float("inf")), (-4.125, -float("inf")),
- (float("inf"), float("inf")), (float("inf"), -float("inf")),
- (-float("inf"), float("inf")),
- (-float("inf"), -float("inf")))
- for dtype in np.float32, np.float64:
- x1 = np.array(x1l).astype(dtype)
- x2 = np.array(x2l).astype(dtype)
- self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
- self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
-
- def testPowNegativeExponent(self):
- for dtype in [np.int32, np.int64]:
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = np.array([-2, 3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
-
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = np.array([2, -3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
-
- with self.test_session(use_gpu=False) as sess:
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- "Integers to negative integer powers are not allowed"):
- x = np.array([5, 2]).astype(dtype)
- y = -3
- sess.run(math_ops.pow(x, y))
-
-
class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype):
@@ -1470,7 +324,7 @@ class SelectOpTest(test.TestCase):
self.assertShapeEqual(np_ans, out)
def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1494,7 +348,7 @@ class SelectOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1582,7 +436,7 @@ class SelectOpTest(test.TestCase):
x = np.random.rand(1, 3, 0) * 100
y = np.random.rand(1, 3, 0) * 100
z_expected = np.zeros((1, 3, 0), dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
xt = x.astype(np.float32)
yt = y.astype(np.float32)
z = array_ops.where(c, xt, yt).eval()
@@ -1590,7 +444,7 @@ class SelectOpTest(test.TestCase):
def testNan(self):
"""Verify that nans don't propagate where they shouldn't."""
- with self.test_session():
+ with self.cached_session():
for c in False, True:
for a in 7.0, np.nan:
for b in 5.0, np.nan:
@@ -1614,7 +468,7 @@ class BatchSelectOpTest(test.TestCase):
self.assertShapeEqual(np_ans, out)
def _compareGradientX(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1638,7 +492,7 @@ class BatchSelectOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, c, x, y, numeric_gradient_type=None):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = array_ops.where(c, inx, iny)
@@ -1745,7 +599,7 @@ class MinMaxOpTest(test.TestCase):
self._compare(x.astype(t), t(y), use_gpu=True)
def _compareGradientX(self, func, x, y):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = func(inx, iny)
@@ -1760,7 +614,7 @@ class MinMaxOpTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
def _compareGradientY(self, func, x, y):
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = func(inx, iny)
@@ -1932,7 +786,7 @@ class RoundingTest(test.TestCase):
def _compare_values(self, x, y=None):
y = np.rint(x) if y is None else np.asarray(y)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tf_rint = math_ops.rint(x)
np_rint = sess.run(tf_rint)
self.assertAllEqual(y, np_rint)
@@ -1940,7 +794,7 @@ class RoundingTest(test.TestCase):
def _compare(self, x):
np_floor, np_ceil = np.floor(x), np.ceil(x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inx = ops.convert_to_tensor(x)
ofloor, oceil = math_ops.floor(inx), math_ops.ceil(inx)
tf_floor, tf_ceil = sess.run([ofloor, oceil])
@@ -2099,7 +953,7 @@ class ComplexMakeRealImagTest(test.TestCase):
# computes the squared sum. This is obviously the same as sum(real
# * real) + sum(imag * imag). We just want to make sure the
# gradient function is checked.
- with self.test_session():
+ with self.cached_session():
inx = ops.convert_to_tensor(x)
real, imag = array_ops.split(value=inx, num_or_size_splits=2, axis=1)
real, imag = array_ops.reshape(real, [-1]), array_ops.reshape(imag, [-1])
@@ -2116,7 +970,7 @@ class ComplexMakeRealImagTest(test.TestCase):
def _compareBroadcastGradient(self, x):
x_ = ops.convert_to_tensor(x)
epsilon = 1e-3
- with self.test_session():
+ with self.cached_session():
for args in [(x_, 0.), (0., x_)]:
z = math_ops.reduce_sum(math_ops.abs(math_ops.complex(*args)))
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -2136,7 +990,7 @@ class ComplexMakeRealImagTest(test.TestCase):
# data is a float matrix of shape [n, 4]. data[:, 0], data[:, 1],
# data[:, 2], data[:, 3] are real parts of x, imaginary parts of
# x, real parts of y and imaginary parts of y.
- with self.test_session():
+ with self.cached_session():
inp = ops.convert_to_tensor(data)
xr, xi, yr, yi = array_ops.split(value=inp, num_or_size_splits=4, axis=1)
@@ -2166,7 +1020,7 @@ class ComplexMakeRealImagTest(test.TestCase):
class AccumulateTest(test.TestCase):
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
random_arrays = [
np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20)
]
@@ -2181,20 +1035,20 @@ class AccumulateTest(test.TestCase):
self.assertAllClose(np_val, tf_val.eval())
def testZeroArgs(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
tf_val = math_ops.accumulate_n([])
tf_val.eval()
def testWrongShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
a = variables.Variable(0.2)
b = variables.Variable(0.1)
math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[]
def testWrongType(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
b = variables.Variable(0.1, dtype=np.float32)
@@ -2202,7 +1056,7 @@ class AccumulateTest(test.TestCase):
def testWrongTypeOneInput(self):
# Scenario that used to trigger a bug, even when testWrongType() worked
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
a = variables.Variable(0.2, dtype=np.float32)
math_ops.accumulate_n([a], tensor_dtype=np.int32)
@@ -2214,7 +1068,7 @@ class PolyvalTest(test.TestCase):
x = np.random.rand(2, 2).astype(dtype)
coeffs = [np.random.rand(2, 2).astype(dtype) for _ in range(degree + 1)]
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
@@ -2237,7 +1091,7 @@ class PolyvalTest(test.TestCase):
for _ in range(degree + 1)
]
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
@@ -2245,7 +1099,7 @@ class PolyvalTest(test.TestCase):
x = np.random.rand(2, 2).astype(np.float32)
coeffs = []
np_val = np.polyval(coeffs, x)
- with self.test_session():
+ with self.cached_session():
tf_val = math_ops.polyval(coeffs, x)
self.assertAllClose(np_val, tf_val.eval())
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
new file mode 100644
index 0000000000..77f182784e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -0,0 +1,541 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for unary coefficient-wise operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+_NEG = lambda x: -x
+_ABS = abs
+
+
+# TODO(zongheng): it'd be great to factor out this function and various random
+# SparseTensor gen funcs.
+def _sparsify(x, thresh=0.5, index_dtype=np.int64):
+ x[x < thresh] = 0
+
+ non_zero = np.where(x)
+ x_indices = np.vstack(non_zero).astype(index_dtype).T
+ x_values = x[non_zero]
+ x_shape = x.shape
+
+ return sparse_tensor.SparseTensor(
+ indices=x_indices, values=x_values, dense_shape=x_shape), x_values
+
+
+def _default_tolerance(dtype):
+ """Returns a sensible default tolerance for comparing results of a given type.
+
+ Args:
+ dtype: A datatype.
+ """
+ if dtype == np.float16:
+ return 5e-3
+ elif dtype in (np.float32, np.complex64):
+ return 1e-3
+ elif dtype in (np.float64, np.complex128):
+ return 1e-5
+ else:
+ return None # Fail fast for unexpected types
+
+
+class UnaryOpTest(test.TestCase):
+
+ def _compareCpu(self, x, np_func, tf_func, grad_rtol=None, grad_atol=None):
+ if grad_rtol is None:
+ grad_rtol = _default_tolerance(x.dtype)
+ if grad_atol is None:
+ grad_atol = _default_tolerance(x.dtype)
+ np_ans = np_func(x)
+ with self.test_session(use_gpu=False):
+ inx = ops.convert_to_tensor(x)
+ if x.dtype in (np.float32, np.float64,
+ dtypes_lib.bfloat16.as_numpy_dtype):
+ y = 1.1 * tf_func(inx)
+ np_ans *= 1.1
+ else:
+ y = tf_func(inx)
+ tf_cpu = y.eval()
+ self.assertShapeEqual(np_ans, y)
+ if x.dtype == np.float16:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-3, atol=1e-3)
+ elif x.dtype == dtypes_lib.bfloat16.as_numpy_dtype:
+ self.assertAllClose(np_ans, tf_cpu, rtol=1e-2, atol=1e-2)
+ else:
+ self.assertAllClose(np_ans, tf_cpu)
+
+ if x.dtype in (np.complex64, np.complex128) and tf_func == math_ops.sign:
+ return # Return early
+
+ if x.dtype == np.float16:
+ s = list(np.shape(x))
+ jacob_t, _ = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x)
+ xf = x.astype(np.float)
+ inxf = ops.convert_to_tensor(xf)
+ yf = tf_func(inxf)
+ _, jacob_n = gradient_checker.compute_gradient(
+ inxf, s, yf, s, x_init_value=xf, delta=1e-2)
+ jacob_n = jacob_n.astype(np.float16)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+ elif x.dtype in (np.float32, np.complex64):
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x, delta=1e-3)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+ elif x.dtype in (np.float64, np.complex128):
+ s = list(np.shape(x))
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, s, y, s, x_init_value=x, delta=1e-5)
+ self.assertAllClose(jacob_t, jacob_n, rtol=grad_rtol, atol=grad_atol)
+
+ def _check(self, result_tensor, result_np, input_sp_t, tol):
+ self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
+ self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
+ self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
+ self.assertAllEqual(input_sp_t.dense_shape.eval(),
+ result_tensor.dense_shape.eval())
+ if tol is None:
+ self.assertAllClose(result_np, result_tensor.values.eval())
+ else:
+ self.assertAllClose(
+ result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
+
+ def _compareSparseCpu(self, x, np_func, tf_func, tol):
+ x_sp, x_sp_vals = _sparsify(x)
+ res_np = np_func(x_sp_vals)
+ with self.test_session(use_gpu=False):
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+ def _compareGpu(self, x, np_func, tf_func):
+ np_ans = np_func(x)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ result = tf_func(ops.convert_to_tensor(x))
+ tf_gpu = result.eval()
+ if x.dtype == np.float16:
+ self.assertAllClose(np_ans, tf_gpu, rtol=1e-3, atol=1e-3)
+ else:
+ self.assertAllClose(np_ans, tf_gpu)
+ # TODO(zhifengc/ke): make gradient checker work on GPU.
+
+ def _compareSparseGpu(self, x, np_func, tf_func, tol):
+ x_sp, x_sp_vals = _sparsify(x)
+ res_np = np_func(x_sp_vals)
+ with self.test_session(force_gpu=test_util.is_gpu_available()):
+ self._check(tf_func(x_sp), res_np, x_sp, tol)
+
+ def _compareBoth(self, x, np_func, tf_func):
+ self._compareCpu(x, np_func, tf_func)
+ self._compareGpu(x, np_func, tf_func)
+
+ def _compareBothSparse(self, x, np_func, tf_func, tol=None):
+ self._compareSparseCpu(x, np_func, tf_func, tol)
+ self._compareSparseGpu(x, np_func, tf_func, tol)
+
+ def _inv(self, x):
+ return 1.0 / x
+
+ def _rsqrt(self, x):
+ return self._inv(np.sqrt(x))
+
+ def _sigmoid(self, x):
+ return 1.0 / (1.0 + np.exp(-x))
+
+ def _log_sigmoid(self, x):
+ return np.log(self._sigmoid(x))
+
+ def _replace_domain_error_with_inf(self, fn):
+
+ def func(x):
+ try:
+ return fn(x)
+ except ValueError as e:
+ if "domain error" in str(e):
+ return np.inf * np.ones_like(x)
+ else:
+ raise e
+
+ return func
+
+ def testFloatBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
+ w = x - x.min() + 1.02 # all greater than 1
+ y = (x + .5).astype(np.float32) # no zero
+ z = (x + 15.5).astype(np.float32) # all positive
+ k = np.arange(-0.90, 0.90, 0.25).astype(np.float32) # between -1 and 1
+
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(w, np.arccosh, math_ops.acosh)
+ self._compareBoth(k, np.arctanh, math_ops.atanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(x, self._log_sigmoid, math_ops.log_sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(k, np.arcsin, math_ops.asin)
+ self._compareBoth(k, np.arccos, math_ops.acos)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ self._compareBoth(x, np.tan, math_ops.tan)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+ def testFloatTanhEdge(self):
+ x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ x = np.arange(-40, -40 + 6).reshape(6).astype(np.float32)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+
+ def testFloatEmpty(self):
+ x = np.empty((2, 0, 5), dtype=np.float32)
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(x, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(x, np.sqrt, math_ops.sqrt)
+ self._compareBoth(x, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(x, np.log, math_ops.log)
+ self._compareBoth(x, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(x, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ # Can't use vectorize below, so just use some arbitrary function
+ self._compareBoth(x, np.sign, math_ops.lgamma)
+ self._compareBoth(x, np.sign, math_ops.erf)
+ self._compareBoth(x, np.sign, math_ops.erfc)
+ self._compareBoth(x, np.tan, math_ops.tan)
+ self._compareBoth(x, np.arcsin, math_ops.asin)
+ self._compareBoth(x, np.arccos, math_ops.acos)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.sign, math_ops.erf)
+
+ def testDoubleBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
+ w = x - x.min() + 1.02 # all greater than 1
+ y = (x + .5).astype(np.float64) # no zero
+ z = (x + 15.5).astype(np.float64) # all positive
+ k = np.arange(-0.90, 0.90,
+ 0.35).reshape(1, 3, 2).astype(np.float64) # between -1 and 1
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.sinh, math_ops.sinh)
+ self._compareBoth(x, np.cosh, math_ops.cosh)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, np.arcsinh, math_ops.asinh)
+ self._compareBoth(w, np.arccosh, math_ops.acosh)
+ self._compareBoth(k, np.arctanh, math_ops.atanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ self._compareBoth(x, np.arctan, math_ops.atan)
+ self._compareBoth(k, np.arcsin, math_ops.asin)
+ self._compareBoth(k, np.arccos, math_ops.acos)
+ self._compareBoth(k, np.tan, math_ops.tan)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf)
+
+ def testHalfBasic(self):
+ x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
+ y = (x + .5).astype(np.float16) # no zero
+ z = (x + 15.5).astype(np.float16) # all positive
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(y, self._inv, math_ops.reciprocal)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareBoth(z, np.sqrt, math_ops.sqrt)
+ self._compareBoth(z, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareBoth(x, np.expm1, math_ops.expm1)
+ self._compareBoth(z, np.log, math_ops.log)
+ self._compareBoth(z, np.log1p, math_ops.log1p)
+ self._compareBoth(x, np.tanh, math_ops.tanh)
+ self._compareBoth(x, self._sigmoid, math_ops.sigmoid)
+ self._compareBoth(y, np.sign, math_ops.sign)
+ self._compareBoth(x, np.sin, math_ops.sin)
+ self._compareBoth(x, np.cos, math_ops.cos)
+ self._compareBoth(
+ y, np.vectorize(self._replace_domain_error_with_inf(math.lgamma)),
+ math_ops.lgamma)
+ self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
+ self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
+ try:
+ from scipy import special # pylint: disable=g-import-not-at-top
+ self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
+ self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
+ except ImportError as e:
+ tf_logging.warn("Cannot test special functions: %s" % str(e))
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(z, np.sqrt, math_ops.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+ self._compareBothSparse(y, np.sign, math_ops.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), math_ops.erf, tol=1e-3)
+
+ def testInt32Basic(self):
+ x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
+ self._compareCpu(x, np.abs, math_ops.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareBoth(x, np.square, math_ops.square)
+ self._compareCpu(x, np.sign, math_ops.sign)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+
+ def testInt64Basic(self):
+ x = np.arange(-6 << 40, 6 << 40, 2 << 40).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.abs, math_ops.abs)
+ self._compareCpu(x, np.abs, _ABS)
+ self._compareCpu(x, np.negative, math_ops.negative)
+ self._compareCpu(x, np.negative, _NEG)
+ self._compareCpu(x, np.sign, math_ops.sign)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.sign, math_ops.sign)
+
+ def testInt64Square(self):
+ x = np.arange(-6 << 20, 6 << 20, 2 << 20).reshape(1, 3, 2).astype(np.int64)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.square, math_ops.square)
+
+ def testComplex64Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex64)
+ y = x + np.complex(0.5, 0.5) # no zeros
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, math_ops.reciprocal)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareCpu(y, np.sqrt, math_ops.sqrt)
+ self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareCpu(x, np.expm1, math_ops.expm1)
+ self._compareCpu(y, np.log, math_ops.log)
+ self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
+ self._compareCpu(x, np.tanh, math_ops.tanh)
+
+ # Complex64 versions of asinh() and acosh() in libstdc++ only have 6 digits
+ # of precision.
+ # Small gradient values + low precision --> High relative error
+ self._compareCpu(y, np.arcsinh, math_ops.asinh, grad_rtol=1e-2)
+ self._compareCpu(y, np.arccosh, math_ops.acosh, grad_rtol=1e-2)
+
+ self._compareCpu(y, np.arctanh, math_ops.atanh)
+ self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+ self._compareCpu(x, np.sin, math_ops.sin)
+ self._compareCpu(x, np.cos, math_ops.cos)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+ # Numpy uses an incorrect definition of sign; use the right one instead.
+ def complex_sign(x):
+ return x / np.abs(x)
+
+ self._compareBoth(y, complex_sign, math_ops.sign)
+ self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+ def testComplex128Basic(self):
+ x = np.complex(1, 1) * np.arange(-3, 3).reshape(1, 3, 2).astype(
+ np.complex128)
+ y = x + np.complex(0.5, 0.5) # no zeros
+ self._compareBoth(x, np.abs, math_ops.abs)
+ self._compareBoth(x, np.abs, _ABS)
+ self._compareBoth(x, np.negative, math_ops.negative)
+ self._compareBoth(x, np.negative, _NEG)
+ self._compareCpu(y, self._inv, math_ops.reciprocal)
+ self._compareCpu(x, np.square, math_ops.square)
+ self._compareCpu(y, np.sqrt, math_ops.sqrt)
+ self._compareCpu(y, self._rsqrt, math_ops.rsqrt)
+ self._compareBoth(x, np.exp, math_ops.exp)
+ self._compareCpu(x, np.expm1, math_ops.expm1)
+ self._compareCpu(y, np.log, math_ops.log)
+ self._compareCpu(y, np.log1p, math_ops.log1p)
+ self._compareCpu(x, np.sinh, math_ops.sinh)
+ self._compareCpu(x, np.cosh, math_ops.cosh)
+ self._compareCpu(x, np.tanh, math_ops.tanh)
+ self._compareCpu(y, np.arcsinh, math_ops.asinh)
+ self._compareCpu(y, np.arccosh, math_ops.acosh)
+ self._compareCpu(y, np.arctanh, math_ops.atanh)
+ self._compareCpu(x, self._sigmoid, math_ops.sigmoid)
+ self._compareCpu(x, np.sin, math_ops.sin)
+ self._compareCpu(x, np.cos, math_ops.cos)
+
+ self._compareBothSparse(x, np.abs, math_ops.abs)
+ self._compareBothSparse(x, np.negative, math_ops.negative)
+ self._compareBothSparse(x, np.square, math_ops.square)
+ self._compareBothSparse(x, np.sqrt, math_ops.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, math_ops.tanh)
+
+ # Numpy uses an incorrect definition of sign; use the right one instead.
+ def complex_sign(x):
+ return x / np.abs(x)
+
+ self._compareBoth(y, complex_sign, math_ops.sign)
+ self._compareBothSparse(y, complex_sign, math_ops.sign)
+
+ def testGradGrad(self):
+ np.random.seed(7)
+ shape = (5,)
+ dtype_tols = [(np.float32, 5e-4), (np.float64, 1e-6), (np.complex64, 5e-4),
+ (np.complex128, 1e-6)]
+ op_range = [
+ (gen_math_ops.reciprocal_grad, [-2, 2]),
+ (gen_math_ops.rsqrt_grad, [0.1, 3]),
+ (gen_math_ops.sigmoid_grad, [-2, 2]),
+ (gen_math_ops.sqrt_grad, [0.1, 3]),
+ (gen_math_ops.tanh_grad, [-2, 2]),
+ ]
+
+ def rand(dtype, real_range):
+ x = np.random.uniform(
+ real_range[0], real_range[1], size=shape[0]).astype(dtype)
+ if dtype in (np.complex64, np.complex128):
+ x += 1j * np.random.uniform(-2, 2, size=shape[0]).astype(dtype)
+ return x
+
+ for op, real_range in op_range:
+ with self.cached_session():
+ for dtype, tol in dtype_tols:
+ x = constant_op.constant(rand(dtype, real_range))
+ y = constant_op.constant(rand(dtype, real_range))
+ z = op(x, y)
+ grads = gradient_checker.compute_gradient(
+ [x, y], [shape, shape],
+ z,
+ shape,
+ x_init_value=[rand(dtype, real_range),
+ rand(dtype, real_range)])
+ if isinstance(grads, tuple):
+ grads = [grads]
+ for analytical, numerical in grads:
+ self.assertAllClose(analytical, numerical, rtol=tol, atol=tol)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
index 35f8f76991..eebaffbe13 100644
--- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
@@ -60,7 +60,7 @@ class DecodeBmpOpTest(test.TestCase):
img_in = constant_op.constant(byte_string, dtype=dtypes.string)
decode = array_ops.squeeze(image_ops.decode_bmp(img_in))
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
@@ -135,7 +135,7 @@ class DecodeBmpOpTest(test.TestCase):
img_in = constant_op.constant(byte_string, dtype=dtypes.string)
decode = image_ops.decode_bmp(img_in)
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
diff --git a/tensorflow/python/kernel_tests/decode_compressed_op_test.py b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
index c9bda58ca7..1cc1c7da30 100644
--- a/tensorflow/python/kernel_tests/decode_compressed_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_compressed_op_test.py
@@ -44,7 +44,7 @@ class DecodeCompressedOpTest(test.TestCase):
def testDecompress(self):
for compression_type in ["ZLIB", "GZIP", ""]:
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
decompressed = parsing_ops.decode_compressed(
in_bytes, compression_type=compression_type)
@@ -57,7 +57,7 @@ class DecodeCompressedOpTest(test.TestCase):
def testDecompressWithRaw(self):
for compression_type in ["ZLIB", "GZIP", ""]:
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decompressed = parsing_ops.decode_compressed(
in_bytes, compression_type=compression_type)
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 4f49d72676..e9307a6b2f 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -20,28 +20,30 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class DecodeCSVOpTest(test.TestCase):
def _test(self, args, expected_out=None, expected_err_re=None):
- with self.test_session() as sess:
+ if expected_err_re is None:
decode = parsing_ops.decode_csv(**args)
-
- if expected_err_re is None:
- out = sess.run(decode)
-
- for i, field in enumerate(out):
- if field.dtype == np.float32 or field.dtype == np.float64:
- self.assertAllClose(field, expected_out[i])
- else:
- self.assertAllEqual(field, expected_out[i])
-
- else:
- with self.assertRaisesOpError(expected_err_re):
- sess.run(decode)
+ out = self.evaluate(decode)
+
+ for i, field in enumerate(out):
+ if field.dtype == np.float32 or field.dtype == np.float64:
+ self.assertAllClose(field, expected_out[i])
+ else:
+ self.assertAllEqual(field, expected_out[i])
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ decode = parsing_ops.decode_csv(**args)
+ self.evaluate(decode)
def testSimple(self):
args = {
@@ -53,6 +55,31 @@ class DecodeCSVOpTest(test.TestCase):
self._test(args, expected_out)
+ def testSimpleWithScalarDefaults(self):
+ args = {
+ "records": ["1,4", "2,5", "3,6"],
+ "record_defaults": [1, 2],
+ }
+
+ expected_out = [[1, 2, 3], [4, 5, 6]]
+
+ self._test(args, expected_out)
+
+ def testSimpleWith2DDefaults(self):
+ args = {
+ "records": ["1", "2", "3"],
+ "record_defaults": [[[0]]],
+ }
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ "Each record default should be at "
+ "most rank 1.")
+ else:
+ err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test(args)
+
def testSimpleNoQuoteDelimiter(self):
args = {
"records": ["1", "2", '"3"'],
diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py
index 58280432d6..7f73fbaa84 100644
--- a/tensorflow/python/kernel_tests/decode_image_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_image_op_test.py
@@ -111,7 +111,7 @@ class DecodeImageOpTest(test.TestCase):
def testInvalidBytes(self):
image_bytes = b"ThisIsNotAnImage!"
decode = image_ops.decode_image(image_bytes)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
decode.eval()
diff --git a/tensorflow/python/kernel_tests/decode_png_op_test.py b/tensorflow/python/kernel_tests/decode_png_op_test.py
index d2e03938ee..8f36343667 100644
--- a/tensorflow/python/kernel_tests/decode_png_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_png_op_test.py
@@ -46,7 +46,7 @@ class DecodePngOpTest(test.TestCase):
image_ops.decode_png(
img_in, dtype=dtypes.uint16))
- with self.test_session():
+ with self.cached_session():
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py
index 122a9ed469..dcc984811c 100644
--- a/tensorflow/python/kernel_tests/decode_raw_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class DecodeRawOpTest(test.TestCase):
def testToUint8(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[2])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint8)
self.assertEqual([2, None], decode.get_shape().as_list())
@@ -47,7 +47,7 @@ class DecodeRawOpTest(test.TestCase):
decode.eval(feed_dict={in_bytes: ["short", "longer"]})
def testToInt16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.int16)
self.assertEqual([None, None], decode.get_shape().as_list())
@@ -62,7 +62,7 @@ class DecodeRawOpTest(test.TestCase):
decode.eval(feed_dict={in_bytes: ["123", "456"]})
def testEndianness(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode_le = parsing_ops.decode_raw(
in_bytes, out_type=dtypes.int32, little_endian=True)
@@ -74,18 +74,18 @@ class DecodeRawOpTest(test.TestCase):
self.assertAllEqual([[0x01020304]], result)
def testToFloat16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
self.assertEqual([None, None], decode.get_shape().as_list())
- expected_result = np.matrix([[1, -2, -3, 4]], dtype=np.float16)
+ expected_result = np.matrix([[1, -2, -3, 4]], dtype="<f2")
result = decode.eval(feed_dict={in_bytes: [expected_result.tostring()]})
self.assertAllEqual(expected_result, result)
def testEmptyStringInput(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16)
@@ -94,7 +94,7 @@ class DecodeRawOpTest(test.TestCase):
self.assertEqual((num_inputs, 0), result.shape)
def testToUInt16(self):
- with self.test_session():
+ with self.cached_session():
in_bytes = array_ops.placeholder(dtypes.string, shape=[None])
decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.uint16)
self.assertEqual([None, None], decode.get_shape().as_list())
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
index d33bf1ba12..affbaf159d 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -33,7 +33,7 @@ class AssignOpTest(test.TestCase):
# contain benign and deliberate data races when multiple threads update
# the same parameters without a lock.
def testParallelUpdateWithoutLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(array_ops.zeros([1024, 1024]))
adds = [
@@ -60,7 +60,7 @@ class AssignOpTest(test.TestCase):
self.assertTrue((vals <= ones * 20).all())
def testParallelAssignWithoutLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ones_t = array_ops.fill([1024, 1024], float(1))
p = variables.Variable(array_ops.zeros([1024, 1024]))
assigns = [
@@ -92,7 +92,7 @@ class AssignOpTest(test.TestCase):
# returning the output tensors. This issue will be resolved with the new
# resource variables.
def testParallelUpdateWithLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
zeros_t = array_ops.fill([1024, 1024], 0.0)
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(zeros_t)
@@ -119,7 +119,7 @@ class AssignOpTest(test.TestCase):
self.assertAllEqual(vals, ones * 20)
def testParallelAssignWithLocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
zeros_t = array_ops.fill([1024, 1024], 0.0)
ones_t = array_ops.fill([1024, 1024], 1.0)
p = variables.Variable(zeros_t)
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 4dda9f093b..06c3271850 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -85,7 +85,7 @@ class AssignOpTest(test.TestCase):
self._testTypes(np.arange(0, 20).reshape([4, 5]))
def testAssignNonStrictShapeChecking(self):
- with self.test_session():
+ with self.cached_session():
data = array_ops.fill([1024, 1024], 0)
p = variables.Variable([1])
a = state_ops.assign(p, data, validate_shape=False)
@@ -99,14 +99,14 @@ class AssignOpTest(test.TestCase):
self.assertAllEqual(p.eval(), data2.eval())
def testInitRequiredAssignAdd(self):
- with self.test_session():
+ with self.cached_session():
p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
def testInitRequiredAssignSub(self):
- with self.test_session():
+ with self.cached_session():
p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 5741f2ec64..6d1ead20be 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -128,7 +128,7 @@ class DepthwiseConv2DTest(test.TestCase):
x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
dtypes.float32: 1e-8,
@@ -191,7 +191,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
"%s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
tf_logging.info("Testing without grouped_conv")
self._VerifyValues(
input_size, filter_size, stride, padding, data_type, use_gpu=True)
@@ -227,7 +227,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._VerifyValues(
input_size,
filter_size,
@@ -366,7 +366,7 @@ class DepthwiseConv2DTest(test.TestCase):
filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
ops.reset_default_graph()
graph = ops.get_default_graph()
- with self.test_session(graph=graph, use_gpu=use_gpu) as sess:
+ with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-0,
dtypes.float32: 8e-4,
@@ -434,7 +434,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
"padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -465,7 +465,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -483,7 +483,7 @@ class DepthwiseConv2DTest(test.TestCase):
tf_logging.info(
"Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
"%d, padding: %s", index, input_size, filter_size, stride, padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
@@ -504,7 +504,7 @@ class DepthwiseConv2DTest(test.TestCase):
"Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
"stride: %d, padding: %s", index, input_size, filter_size, stride,
padding)
- for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ for data_type in [dtypes.float32, dtypes.float64]:
self._ConstructAndTestGradient(
input_size,
filter_size,
diff --git a/tensorflow/python/kernel_tests/division_future_test.py b/tensorflow/python/kernel_tests/division_future_test.py
index e681b32856..e477bdc73b 100644
--- a/tensorflow/python/kernel_tests/division_future_test.py
+++ b/tensorflow/python/kernel_tests/division_future_test.py
@@ -50,7 +50,7 @@ class DivisionTestCase(test.TestCase):
self.assertEqual(x, y)
checks.append(f)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for dtype in dtypes:
for x in map(dtype, values):
for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/division_past_test.py b/tensorflow/python/kernel_tests/division_past_test.py
index 9ddd62e63c..63951b5b38 100644
--- a/tensorflow/python/kernel_tests/division_past_test.py
+++ b/tensorflow/python/kernel_tests/division_past_test.py
@@ -49,7 +49,7 @@ class DivisionTestCase(test.TestCase):
self.assertEqual(x, y)
checks.append(f)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for dtype in dtypes:
for x in map(dtype, values):
for y in map(dtype, values):
diff --git a/tensorflow/python/kernel_tests/duplicate_op_test.py b/tensorflow/python/kernel_tests/duplicate_op_test.py
index 529d3dd0b3..654267a582 100644
--- a/tensorflow/python/kernel_tests/duplicate_op_test.py
+++ b/tensorflow/python/kernel_tests/duplicate_op_test.py
@@ -34,7 +34,7 @@ class DuplicateOpTest(test.TestCase):
self.assertEqual(len(duplicate.OP_LIST.op), 0)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(math_ops.add(1, 41).eval(), 42)
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 5e8937ad2c..9557e30993 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -288,7 +288,7 @@ class DynamicPartitionTest(test.TestCase):
self.assertAllEqual([], partition_vals[i])
def testErrorIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
indices = constant_op.constant([0, 2, 99, 2, 2])
@@ -298,7 +298,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testScalarIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bad = 17
data = np.zeros(5)
partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
@@ -306,7 +306,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testHigherRankIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = (2, 3)
indices = array_ops.placeholder(shape=shape, dtype=np.int32)
data = np.zeros(shape + (5,))
@@ -334,7 +334,7 @@ class DynamicPartitionTest(test.TestCase):
inds += [13]*194 + [14]*194 + [15]*192
self.assertEqual(len(inds), x.shape[0])
partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
res = sess.run(partitioned)
self.assertEqual(res[-1].shape[0], 192)
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index 49b9569e2b..3a1036e52a 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -252,7 +252,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
# GPU version unit tests
def testScalarGPU(self):
- with self.test_session():
+ with self.cached_session():
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
for step in -1, 1:
@@ -263,7 +263,7 @@ class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):
self.assertEqual([2], stitched_t.get_shape().as_list())
def testHigherRankGPU(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = [
constant_op.constant(6),
constant_op.constant([4, 1]),
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index dcd435e1ff..40b8548cea 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -242,7 +242,7 @@ class EmbeddingLookupTest(test.TestCase):
# vector is going to be empty. The subsequent DivOp fails because of that.
# TODO(keveman): Disabling the test until the underlying problem is fixed.
def testSimpleSharded(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 2
vocab_size = 4
p, params, feed_dict = _EmbeddingParams(num_shards, vocab_size)
@@ -258,7 +258,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testMaxNorm(self):
- with self.test_session():
+ with self.cached_session():
embeddings = constant_op.constant([[2.0]])
ids = constant_op.constant([0], dtype=dtypes.int32)
@@ -268,7 +268,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding.eval(), [[1.0]])
def testMaxNormNontrivial(self):
- with self.test_session():
+ with self.cached_session():
embeddings = constant_op.constant([[2.0, 4.0], [3.0, 1.0]])
ids = constant_op.constant([0, 1], dtype=dtypes.int32)
@@ -281,7 +281,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding.eval(), 2 * normalized.eval())
def testSimpleShardedPartitionedVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_shards = 2
vocab_size = 4
p, p_variable, params, feed_dict = _EmbeddingParamsAsPartitionedVariable(
@@ -303,7 +303,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testSimpleShardedPartitionedResourceVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_shards = 2
vocab_size = 4
p, p_variable, params, _ = _EmbeddingParamsAsPartitionedVariable(
@@ -326,7 +326,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedModPartitioningInt32Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -348,7 +348,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedModPartitioningInt64Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -370,7 +370,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt32Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -394,7 +394,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt32IdsPartitionedVariable(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -419,7 +419,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningInt64Ids(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -443,7 +443,7 @@ class EmbeddingLookupTest(test.TestCase):
self.assertShapeEqual(np_result, embedding)
def testShardedDivPartitioningUnknownParamShape(self):
- with self.test_session():
+ with self.cached_session():
num_shards = 5
vocab_size = 13
# Embedding dimensions is 10. The vocab_size x 10 embedding
@@ -475,7 +475,7 @@ class EmbeddingLookupTest(test.TestCase):
tf_logging.vlog(1, id_vals)
for ids_shape in [(10,), (2, 5)]:
for num_shards in [1, 3]:
- with self.test_session():
+ with self.cached_session():
ids = constant_op.constant(
id_vals, shape=ids_shape, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
@@ -494,7 +494,7 @@ class EmbeddingLookupTest(test.TestCase):
id_vals = list(np.random.randint(vocab_size, size=num_ids))
tf_logging.vlog(1, id_vals)
for num_shards in [1, 3]:
- with self.test_session():
+ with self.cached_session():
ids = constant_op.constant(id_vals, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
# This will force a conversion from IndexedSlices to Tensor.
@@ -528,7 +528,7 @@ class EmbeddingLookupTest(test.TestCase):
def testHigherRank(self):
np.random.seed(8)
- with self.test_session():
+ with self.cached_session():
for params_shape in (12,), (6, 3):
params = np.random.randn(*params_shape)
for ids_shape in (3, 2), (4, 3):
@@ -548,7 +548,7 @@ class EmbeddingLookupTest(test.TestCase):
def testHigherRankMaxNorm(self):
np.random.seed(8)
- with self.test_session():
+ with self.cached_session():
for params_shape in (12,), (6, 3), (6, 2, 3):
# Test embedding rank 0, 1, 2.
# Note: the first dimension must be a common multiple of procs below.
@@ -581,7 +581,7 @@ class EmbeddingLookupTest(test.TestCase):
# It always applies max_norm.
np.random.seed(8)
l2_norm = 2.
- with self.test_session():
+ with self.cached_session():
# Param values are in [l2_norm, l2_norm+1) so it will always clip.
params = np.random.rand(6, 3) + l2_norm
params_norm = l2_norm * params / np.sqrt(
@@ -667,7 +667,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
[dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
[True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = embedding_ops.embedding_lookup_sparse(
@@ -716,7 +716,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -734,7 +734,7 @@ class EmbeddingLookupSparseTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
@@ -819,7 +819,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -832,7 +832,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -846,7 +846,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -860,7 +860,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -874,7 +874,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -889,7 +889,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -902,7 +902,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -918,7 +918,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -934,7 +934,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -951,7 +951,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -1035,7 +1035,7 @@ class DynamicStitchOpTest(test.TestCase):
# We expect that the values are merged in order.
def testStitchOrder(self):
- with self.test_session():
+ with self.cached_session():
indices = []
np_values = []
values = []
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index e1f5a6b620..7d9d4e5175 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -83,7 +83,7 @@ class ExtractImagePatchesGradTest(test.TestCase):
random_seed = 42
random_seed_lib.set_random_seed(random_seed)
- with self.test_session():
+ with self.cached_session():
for test_case in self._TEST_CASES:
np.random.seed(random_seed)
in_shape = test_case['in_shape']
diff --git a/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
new file mode 100644
index 0000000000..64757a3e07
--- /dev/null
+++ b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
@@ -0,0 +1,131 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for ExtractVolumePatches op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+class ExtractVolumePatches(test.TestCase):
+ """Functional tests for ExtractVolumePatches op."""
+
+ def _VerifyValues(self, image, ksizes, strides, padding, patches):
+ """Tests input-output pairs for the ExtractVolumePatches op.
+
+ Args:
+ image: Input tensor with shape:
+ [batch, in_planes, in_rows, in_cols, depth].
+ ksizes: Patch size specified as: [ksize_planes, ksize_rows, ksize_cols].
+ strides: Output strides, specified as:
+ [stride_planes, stride_rows, stride_cols].
+ padding: Padding type.
+ patches: Expected output.
+
+ Note:
+ rates are not supported as of now.
+ """
+ ksizes = [1] + ksizes + [1]
+ strides = [1] + strides + [1]
+
+ with self.test_session(use_gpu=True):
+ out_tensor = array_ops.extract_volume_patches(
+ constant_op.constant(image),
+ ksizes=ksizes,
+ strides=strides,
+ padding=padding,
+ name="im2col_3d")
+ self.assertAllClose(patches, out_tensor.eval())
+
+ # pylint: disable=bad-whitespace
+ def testKsize1x1x1Stride1x1x1(self):
+ """Verifies that for 1x1x1 kernel the output equals the input."""
+ image = np.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + 1
+ patches = image
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[1, 1, 1],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x1Stride2x3x4(self):
+ """Test for 1x1x1 kernel and strides."""
+ image = np.arange(6 * 2 * 4 * 5 * 3).reshape([6, 2, 4, 5, 3]) + 1
+ patches = image[:, ::2, ::3, ::4, :]
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 1],
+ strides=[2, 3, 4],
+ padding=padding,
+ patches=patches)
+
+ def testKsize1x1x2Stride2x2x3(self):
+ """Test for 1x1x2 kernel and strides."""
+ image = np.arange(45).reshape([1, 3, 3, 5, 1]) + 1
+ patches = np.array([[[[[ 1, 2],
+ [ 4, 5]],
+ [[11, 12],
+ [14, 15]]],
+ [[[31, 32],
+ [34, 35]],
+ [[41, 42],
+ [44, 45]]]]])
+ for padding in ["VALID", "SAME"]:
+ self._VerifyValues(
+ image,
+ ksizes=[1, 1, 2],
+ strides=[2, 2, 3],
+ padding=padding,
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Valid(self):
+ """Test for 2x2x2 kernel with VALID padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="VALID",
+ patches=patches)
+
+ def testKsize2x2x2Stride1x1x1Same(self):
+ """Test for 2x2x2 kernel with SAME padding."""
+ image = np.arange(8).reshape([1, 2, 2, 2, 1]) + 1
+ patches = np.array([[[[[1, 2, 3, 4, 5, 6, 7, 8],
+ [2, 0, 4, 0, 6, 0, 8, 0]],
+ [[3, 4, 0, 0, 7, 8, 0, 0],
+ [4, 0, 0, 0, 8, 0, 0, 0]]],
+ [[[5, 6, 7, 8, 0, 0, 0, 0],
+ [6, 0, 8, 0, 0, 0, 0, 0]],
+ [[7, 8, 0, 0, 0, 0, 0, 0],
+ [8, 0, 0, 0, 0, 0, 0, 0]]]]])
+ self._VerifyValues(
+ image,
+ ksizes=[2, 2, 2],
+ strides=[1, 1, 1],
+ padding="SAME",
+ patches=patches)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/fft_ops_test.py b/tensorflow/python/kernel_tests/fft_ops_test.py
index 629acedda5..f117934e4b 100644
--- a/tensorflow/python/kernel_tests/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/fft_ops_test.py
@@ -496,7 +496,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
"Input dimension .* must have length of at least 6 but got: 5"):
x = np.zeros((5,) * rank).astype(np.float32)
fft_length = [6] * rank
- with self.test_session():
+ with self.cached_session():
rfft_fn(x, fft_length).eval()
with self.assertRaisesWithPredicateMatch(
@@ -504,7 +504,7 @@ class RFFTOpsTest(BaseFFTOpsTest):
"Input dimension .* must have length of at least .* but got: 3"):
x = np.zeros((3,) * rank).astype(np.complex64)
fft_length = [6] * rank
- with self.test_session():
+ with self.cached_session():
irfft_fn(x, fft_length).eval()
def testGrad_Simple(self):
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index 9e7b528338..a5f8f64e0c 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -99,19 +99,19 @@ class FIFOQueueTest(test.TestCase):
""", q.queue_ref.op.node_def)
def testEnqueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueHalf(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
@@ -120,7 +120,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(1, q.size().eval())
def testEnqueueManyWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(
10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
@@ -143,7 +143,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(self.evaluate(q.dequeue()), 1)
def testEnqueueDictWithoutNames(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})
@@ -151,7 +151,7 @@ class FIFOQueueTest(test.TestCase):
q.enqueue_many({"a": [12.0, 13.0]})
def testParallelEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -177,7 +177,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -201,7 +201,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -215,7 +215,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testDequeueHalf(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float16)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -229,7 +229,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -259,7 +259,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
@@ -275,12 +275,12 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
@@ -293,7 +293,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, size.eval())
def testEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -306,7 +306,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([elems[i % 4]], vals)
def testEmptyEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
empty_t = constant_op.constant(
[], dtype=dtypes_lib.float32, shape=[0, 2, 3])
@@ -318,7 +318,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([0], size_t.eval())
def testEmptyDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_many(0)
@@ -328,7 +328,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueUpTo(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=())
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_up_to(0)
@@ -338,14 +338,14 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueManyWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
# Expect the operation to fail due to the shape not being constrained.
with self.assertRaisesOpError("specified shapes"):
q.dequeue_many(0).eval()
def testMultiEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.float32, dtypes_lib.int32))
float_elems = [10.0, 20.0, 30.0, 40.0]
int_elems = [[1, 2], [3, 4], [5, 6], [7, 8]]
@@ -361,7 +361,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testDequeueUpToNoBlocking(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -385,7 +385,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testMultiDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -416,7 +416,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
def testMultiDequeueUpToNoBlocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -440,7 +440,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[4:8], int_val)
def testHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, (4, 4, 4, 4))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
enqueue_op = q.enqueue_many((elems,))
@@ -494,7 +494,7 @@ class FIFOQueueTest(test.TestCase):
array_ops.placeholder(dtypes_lib.int32)))
def testEnqueueWrongShapeAtRuntime(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
(2, 2), (3, 3)))
elems_ok = np.array([1] * 4).reshape((2, 2)).astype(np.int32)
@@ -506,7 +506,7 @@ class FIFOQueueTest(test.TestCase):
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.int32), (
(2, 2), (3, 3)))
elems_ok = np.array([1] * 8).reshape((2, 2, 2)).astype(np.int32)
@@ -521,7 +521,7 @@ class FIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testParallelEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(100)]
enqueue_op = q.enqueue_many((elems,))
@@ -540,7 +540,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
def testParallelDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -562,7 +562,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(1000, dtypes_lib.float32, shapes=())
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -586,7 +586,7 @@ class FIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(50, dtypes_lib.float32, shapes=())
initial_elements = [10.0] * 49
q.enqueue_many((initial_elements,)).run()
@@ -619,7 +619,7 @@ class FIFOQueueTest(test.TestCase):
self.assertTrue(elem in (10.0, 20.0))
def testMixtureOfEnqueueAndEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -655,7 +655,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testMixtureOfDequeueAndDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.int32, shapes=())
enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
dequeued_t = q.dequeue()
@@ -689,7 +689,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -716,7 +716,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testBlockingDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -743,7 +743,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testDequeueManyWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.FIFOQueue(100, dtypes_lib.int32, ())
@@ -768,7 +768,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(elems, dequeued_elems)
def testDequeueFromClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -786,7 +786,7 @@ class FIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testBlockingDequeueFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -812,7 +812,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
close_op = q.close()
dequeued_t = q.dequeue()
@@ -832,7 +832,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -857,7 +857,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -882,7 +882,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testDequeueUpToFromClosedQueueReturnsRemainder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -904,7 +904,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32, ())
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -941,7 +941,7 @@ class FIFOQueueTest(test.TestCase):
close_thread.join()
def testClosedBlockingDequeueManyRestoresPartialBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, (dtypes_lib.float32, dtypes_lib.float32), (
(), ()))
elems_a = [1.0, 2.0, 3.0]
@@ -974,7 +974,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueManyFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
close_op = q.close()
dequeued_t = q.dequeue_many(4)
@@ -994,7 +994,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, ())
close_op = q.close()
dequeued_t = q.dequeue_up_to(4)
@@ -1014,7 +1014,7 @@ class FIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
close_op = q.close()
@@ -1027,7 +1027,7 @@ class FIFOQueueTest(test.TestCase):
enqueue_op.run()
def testEnqueueManyToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1041,7 +1041,7 @@ class FIFOQueueTest(test.TestCase):
enqueue_op.run()
def testBlockingEnqueueToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1064,7 +1064,7 @@ class FIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueManyToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1091,7 +1091,7 @@ class FIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1128,7 +1128,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingEnqueueManyBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(4, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1161,7 +1161,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(elem, dequeued_t.eval())
def testDoesNotLoseValue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.FIFOQueue(1, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
size_t = q.size()
@@ -1171,7 +1171,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(size_t.eval(), [1])
def testSharedQueueSameSession(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.FIFOQueue(
1, dtypes_lib.float32, shared_name="shared_queue")
q1.enqueue((10.0,)).run()
@@ -1201,7 +1201,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), [0])
def testIncompatibleSharedQueueErrors(self):
- with self.test_session():
+ with self.cached_session():
q_a_1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shared_name="q_a")
q_a_2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32, shared_name="q_a")
q_a_1.queue_ref.op.run()
@@ -1244,7 +1244,7 @@ class FIFOQueueTest(test.TestCase):
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
- with self.test_session():
+ with self.cached_session():
num_queues = 10
qlist = list()
for _ in xrange(num_queues):
@@ -1257,7 +1257,7 @@ class FIFOQueueTest(test.TestCase):
self.assertEqual(q.dequeue().eval(), 10.0)
def testSelectQueueOutOfRange(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
q2 = data_flow_ops.FIFOQueue(15, dtypes_lib.float32)
enq_q = data_flow_ops.FIFOQueue.from_list(3, [q1, q2])
@@ -1281,7 +1281,7 @@ class FIFOQueueTest(test.TestCase):
sess.run(enqueue_many_op)
def testResetOfBlockingOperation(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_empty = data_flow_ops.FIFOQueue(5, dtypes_lib.float32, ())
dequeue_op = q_empty.dequeue()
dequeue_many_op = q_empty.dequeue_many(1)
@@ -1309,7 +1309,7 @@ class FIFOQueueTest(test.TestCase):
t.join()
def testBigEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(5, dtypes_lib.int32, ((),))
elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
enq = q.enqueue_many((elem,))
@@ -1354,7 +1354,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testBigDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(2, dtypes_lib.int32, ((),))
elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
@@ -1380,7 +1380,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
@@ -1411,7 +1411,7 @@ class FIFOQueueTest(test.TestCase):
self.assertAllEqual(input_elem, output_elem)
def testDequeueEnqueueFail(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
a = q.dequeue()
b = control_flow_ops.Assert(False, ["Before enqueue"])
@@ -1474,7 +1474,7 @@ class FIFOQueueDictTest(test.TestCase):
self.assertEqual(["i", "f"], q.names)
def testEnqueueDequeueOneComponent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, dtypes_lib.float32, shapes=((),), names="f")
# Verify that enqueue() checks that when using names we must enqueue a
@@ -1519,7 +1519,7 @@ class FIFOQueueDictTest(test.TestCase):
self.assertEqual([40.0, 50.0], list(f))
def testEnqueueDequeueMultipleComponent(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32, dtypes_lib.string),
shapes=((), (), ()),
@@ -1600,7 +1600,7 @@ class FIFOQueueWithTimeoutTest(test.TestCase):
sess.run(dequeued_t)
def testReusableAfterTimeout(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
dequeued_t = q.dequeue()
enqueue_op = q.enqueue(37)
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
index faac7d8365..f89d2062f1 100644
--- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -127,7 +127,7 @@ class FractionalAvgTest(test.TestCase):
Returns:
None
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_avg_pool(
input_tensor,
pooling_ratio,
@@ -160,7 +160,7 @@ class FractionalAvgTest(test.TestCase):
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_avg_pool(
rand_mat.astype(np.float32),
pooling_ratio,
@@ -234,7 +234,7 @@ class FractionalAvgTest(test.TestCase):
[4, 4, 5, 9, 7, 2]
])
# pyformat: enable
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Since deterministic = True, seed and seed2 are fixed. Therefore r, and c
# are the same each time. We can have an expected result precomputed.
# r = [0, 2, 4, 6]
@@ -314,7 +314,7 @@ class FractionalAvgTest(test.TestCase):
def testDifferentInputTensorShape(self):
"""Runs the operation in one session with different input tensor shapes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_holder = array_ops.placeholder(dtypes.float32,
[None, None, None, 3])
pooling_ratio = [1, 1.5, 1.5, 1]
@@ -389,7 +389,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateRandomInputTensor(input_shape).astype(
np.float32))
@@ -428,7 +428,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateRandomInputTensor(input_shape).astype(
np.float32))
@@ -468,7 +468,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
for pseudo_random in True, False:
for overlapping in True, False:
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
@@ -501,7 +501,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
for num_channels in [1, 3]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
input_data = self._GenerateRandomInputTensor(input_shape)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
@@ -532,7 +532,7 @@ class FractionalAvgPoolGradTest(test.TestCase):
overlapping = True
pseudo_random = False
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_avg_pool(
input_tensor,
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
index 6477c9ebc4..9b94ca8554 100644
--- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -127,7 +127,7 @@ class FractionalMaxPoolTest(test.TestCase):
Returns:
None
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_max_pool(
input_tensor,
pooling_ratio,
@@ -160,7 +160,7 @@ class FractionalMaxPoolTest(test.TestCase):
overlapping))
rand_mat = self._PRNG.randint(10, size=tensor_shape)
pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
p, r, c = nn_ops.fractional_max_pool(
rand_mat,
pooling_ratio,
@@ -285,7 +285,7 @@ class FractionalMaxPoolTest(test.TestCase):
def testDifferentInputTensorShape(self):
"""Runs the operation in one session with different input tensor shapes."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_holder = array_ops.placeholder(dtypes.float32,
[None, None, None, 3])
pooling_ratio = [1, 1.5, 1.5, 1]
@@ -374,7 +374,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
num_cols = col_window_size * 7
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateUniqueRandomInputTensor(input_shape))
window_size = [1, row_window_size, col_window_size, 1]
@@ -409,7 +409,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
num_cols = (col_window_size - 1) * 7 + 1
for num_channels in [1, 2]:
input_shape = (num_batches, num_rows, num_cols, num_channels)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(
self._GenerateUniqueRandomInputTensor(input_shape))
window_size = [1, row_window_size, col_window_size, 1]
@@ -447,7 +447,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
for pseudo_random in True, False:
for overlapping in True, False:
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -482,7 +482,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
input_data = self._GenerateUniqueRandomInputTensor(input_shape)
# Add some randomness to make input_data not so 'integer'
input_data += self._PRNG.random_sample(input_shape)
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -515,7 +515,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
overlapping = True
pseudo_random = False
- with self.test_session() as _:
+ with self.cached_session() as _:
input_tensor = constant_op.constant(input_data, shape=input_shape)
output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool(
input_tensor,
@@ -579,7 +579,7 @@ class FractionalMaxPoolGradTest(test.TestCase):
0.0, 0.0, 0.0, 0.0,
6.0, 0.0, 21.0, 0.0],
input_size) # pyformat: disable
- with self.test_session() as _:
+ with self.cached_session() as _:
# Test when overlapping is False
input_tensor = constant_op.constant(input_data, shape=input_size)
output_tensor = constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index e39daf1371..30d11852c7 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -735,7 +735,7 @@ class FunctionalOpsTest(test.TestCase):
def Run(sess, n):
return sess.run(functional_ops.While([n, 0.], Cond, Body))[1]
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
self.assertAllEqual(Run(sess, 20.), 210.)
self.assertAllEqual(Run(sess, 100.), 5050.)
@@ -765,7 +765,7 @@ class FunctionalOpsTest(test.TestCase):
fetch = outputs[1]
else:
fetch = "my_while:1"
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
return sess.run(fetch)
self.assertAllEqual(Run(20., False), 210.)
@@ -793,7 +793,7 @@ class FunctionalOpsTest(test.TestCase):
def BodyReturnsTooManyArgs(n, x):
return n - 1, x + n, x
- with self.test_session(graph=g, use_gpu=use_gpu):
+ with self.session(graph=g, use_gpu=use_gpu):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Expected a single scalar.*got 2 tensors."):
@@ -818,7 +818,7 @@ class FunctionalOpsTest(test.TestCase):
def Body(n, x):
return n - 1, x + n
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
n = array_ops.placeholder(dtypes.float32)
_, result = functional_ops.While([n, 0.], Cond, Body)
c = constant_op.constant(37.)
@@ -831,7 +831,7 @@ class FunctionalOpsTest(test.TestCase):
def _tfSum(self, use_gpu, rewrite_with_while):
with ops.Graph().as_default() as g:
- with self.test_session(graph=g, use_gpu=use_gpu) as sess:
+ with self.session(graph=g, use_gpu=use_gpu) as sess:
@function.Defun(dtypes.int32, dtypes.float32)
def Body(n, x):
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py
index 033fa95935..85bf969068 100644
--- a/tensorflow/python/kernel_tests/gather_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_op_test.py
@@ -147,7 +147,7 @@ class GatherTest(test.TestCase):
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([b"qwer", b"uiop"],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([b"asdf", b"qwer"],
@@ -157,7 +157,7 @@ class GatherTest(test.TestCase):
for unsigned_type in (dtypes.uint32, dtypes.uint64):
params = self._buildParams(
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([7, 8, 9],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
diff --git a/tensorflow/python/kernel_tests/gradient_correctness_test.py b/tensorflow/python/kernel_tests/gradient_correctness_test.py
index e93c6235f7..291a69ebac 100644
--- a/tensorflow/python/kernel_tests/gradient_correctness_test.py
+++ b/tensorflow/python/kernel_tests/gradient_correctness_test.py
@@ -30,7 +30,7 @@ from tensorflow.python.platform import test
class GradientCorrectnessTest(test.TestCase):
def testMultipleOutputChainedGradients(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(1.0, dtype=dtypes.float32)
yexp = math_ops.exp(x)
yexplog = math_ops.log(yexp)
@@ -43,13 +43,13 @@ class GradientCorrectnessTest(test.TestCase):
def testIdentityGradient(self):
x = constant_op.constant(3.)
dx_dx, = gradients_impl.gradients(x, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose(1., sess.run(dx_dx))
def testIntegerIdentityGradient(self):
x = constant_op.constant(3)
dx_dx, = gradients_impl.gradients(x, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose(1, sess.run(dx_dx))
def testGradientWithIntegerPath(self):
@@ -57,7 +57,7 @@ class GradientCorrectnessTest(test.TestCase):
k = math_ops.to_float(math_ops.to_int32(x))
y = x * k
dy_dx, = gradients_impl.gradients(y, x)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllClose([3., 4.], sess.run(dy_dx))
def testNoIntegerGradient1(self):
diff --git a/tensorflow/python/kernel_tests/identity_n_op_py_test.py b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
index 408b173981..518733cd8e 100644
--- a/tensorflow/python/kernel_tests/identity_n_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_n_op_py_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class IdentityNOpTest(test.TestCase):
def testInt32String_6(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[value0, value1] = sess.run(
array_ops.identity_n([[1, 2, 3, 4, 5, 6],
[b"a", b"b", b"C", b"d", b"E", b"f", b"g"]]))
@@ -37,7 +37,7 @@ class IdentityNOpTest(test.TestCase):
np.array([b"a", b"b", b"C", b"d", b"E", b"f", b"g"]), value1)
def testInt32_shapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inp0 = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
inp1 = constant_op.constant([11, 21, 31, 41, 51, 61], shape=[3, 2])
inp2 = constant_op.constant(
@@ -52,12 +52,12 @@ class IdentityNOpTest(test.TestCase):
def testString(self):
source = [b"A", b"b", b"C", b"d", b"E", b"f"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
[value] = sess.run(array_ops.identity_n([source]))
self.assertAllEqual(source, value)
def testIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
array_2x3 = [[1, 2, 3], [6, 5, 4]]
tensor = constant_op.constant(array_2x3)
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 49fb76d5b4..37f9f716f8 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -31,24 +31,24 @@ from tensorflow.python.platform import test
class IdentityOpTest(test.TestCase):
def testInt32_6(self):
- with self.test_session():
+ with self.cached_session():
value = array_ops.identity([1, 2, 3, 4, 5, 6]).eval()
self.assertAllEqual(np.array([1, 2, 3, 4, 5, 6]), value)
def testInt32_2_3(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant([10, 20, 30, 40, 50, 60], shape=[2, 3])
value = array_ops.identity(inp).eval()
self.assertAllEqual(np.array([[10, 20, 30], [40, 50, 60]]), value)
def testString(self):
source = [b"A", b"b", b"C", b"d", b"E", b"f"]
- with self.test_session():
+ with self.cached_session():
value = array_ops.identity(source).eval()
self.assertAllEqual(source, value)
def testIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
array_2x3 = [[1, 2, 3], [6, 5, 4]]
tensor = constant_op.constant(array_2x3)
@@ -59,7 +59,7 @@ class IdentityOpTest(test.TestCase):
array_ops.identity(np.array(array_2x3)).get_shape())
def testRefIdentityShape(self):
- with self.test_session():
+ with self.cached_session():
shape = [2, 3]
tensor = variables.Variable(
constant_op.constant(
diff --git a/tensorflow/python/kernel_tests/in_topk_op_test.py b/tensorflow/python/kernel_tests/in_topk_op_test.py
index fafeea8ec0..6fdb497bc6 100644
--- a/tensorflow/python/kernel_tests/in_topk_op_test.py
+++ b/tensorflow/python/kernel_tests/in_topk_op_test.py
@@ -30,7 +30,7 @@ class InTopKTest(test.TestCase):
def _validateInTopK(self, predictions, target, k, expected):
np_ans = np.array(expected)
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
@@ -65,7 +65,7 @@ class InTopKTest(test.TestCase):
def testBadTarget(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
target = [0, 80000]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"target.*out of range"):
nn_ops.in_top_k(predictions, target, 2).eval()
@@ -75,7 +75,7 @@ class InTopKTest(test.TestCase):
target = [0, 2]
k = constant_op.constant(3)
np_ans = np.array([False, True])
- with self.test_session():
+ with self.cached_session():
precision = nn_ops.in_top_k(predictions, target, k)
out = precision.eval()
self.assertAllClose(np_ans, out)
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index f6097ad489..292679e4b9 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -343,7 +343,7 @@ class UniformUnitScalingInitializationTest(test.TestCase):
def testZeroSize(self):
shape = [0, 2]
- with self.test_session():
+ with self.cached_session():
x = variable_scope.get_variable(
"x",
shape=shape,
@@ -522,7 +522,7 @@ class LinSpaceTest(test.TestCase):
def _LinSpace(self, start, stop, num):
# NOTE(touts): Needs to pass a graph to get a new session each time.
with ops.Graph().as_default() as graph:
- with self.test_session(graph=graph, force_gpu=self.force_gpu):
+ with self.session(graph=graph, force_gpu=self.force_gpu):
tf_ans = math_ops.linspace(start, stop, num, name="linspace")
self.assertEqual([num], tf_ans.get_shape())
return tf_ans.eval()
@@ -606,7 +606,7 @@ class OrthogonalInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.orthogonal_initializer()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[5])
def testGain(self):
@@ -614,7 +614,7 @@ class OrthogonalInitializerTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float64]:
init1 = init_ops.orthogonal_initializer(seed=1, dtype=dtype)
init2 = init_ops.orthogonal_initializer(gain=3.14, seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -624,7 +624,7 @@ class OrthogonalInitializerTest(test.TestCase):
for shape in [(10, 10), (10, 9, 8), (100, 5, 5), (50, 40), (40, 50)]:
init = init_ops.orthogonal_initializer(dtype=dtype)
tol = 1e-5 if dtype == dtypes.float32 else 1e-12
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
# Check the shape
t = init(shape).eval()
self.assertAllEqual(shape, t.shape)
@@ -663,7 +663,7 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_delta_orthogonal()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5])
def testGain(self):
@@ -672,7 +672,7 @@ class ConvolutionDeltaOrthogonalInitializerTest(test.TestCase):
init1 = init_ops.convolutional_delta_orthogonal(seed=1, dtype=dtype)
init2 = init_ops.convolutional_delta_orthogonal(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -763,7 +763,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_1d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 6, 5])
def testGain(self):
@@ -772,7 +772,7 @@ class ConvolutionOrthogonal1dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_1d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_1d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -877,7 +877,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_2d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 6, 5])
def testGain(self):
@@ -886,7 +886,7 @@ class ConvolutionOrthogonal2dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_2d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_2d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -972,7 +972,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
def testInvalidShape(self):
init1 = init_ops.convolutional_orthogonal_3d()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init1, shape=[3, 3, 3, 6, 5])
def testGain(self):
@@ -981,7 +981,7 @@ class ConvolutionOrthogonal3dInitializerTest(test.TestCase):
init1 = init_ops.convolutional_orthogonal_3d(seed=1, dtype=dtype)
init2 = init_ops.convolutional_orthogonal_3d(gain=3.14,
seed=1, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
t1 = init1(shape).eval()
t2 = init2(shape).eval()
return np.allclose(t1, t2 / 3.14, rtol=1e-15, atol=1e-15)
@@ -1080,7 +1080,7 @@ class IdentityInitializerTest(test.TestCase):
def testInvalidShape(self):
init = init_ops.identity_initializer()
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertRaises(ValueError, init, shape=[5, 7, 7])
self.assertRaises(ValueError, init, shape=[5])
self.assertRaises(ValueError, init, shape=[])
@@ -1088,7 +1088,7 @@ class IdentityInitializerTest(test.TestCase):
def testNonSquare(self):
init = init_ops.identity_initializer()
shape = (10, 5)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init(shape).eval(), np.eye(*shape))
def testGain(self):
@@ -1096,16 +1096,16 @@ class IdentityInitializerTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float64]:
init_default = init_ops.identity_initializer(dtype=dtype)
init_custom = init_ops.identity_initializer(gain=0.9, dtype=dtype)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init_default(shape).eval(), np.eye(*shape))
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
self.assertAllClose(init_custom(shape).eval(), np.eye(*shape) * 0.9)
def testPartitions(self):
shape = (10, 10)
init = init_ops.identity_initializer()
partitioner = partitioned_variables.variable_axis_size_partitioner(1)
- with self.test_session(graph=ops.Graph(), use_gpu=True):
+ with self.session(graph=ops.Graph(), use_gpu=True):
with variable_scope.variable_scope(
"foo", partitioner=partitioner, initializer=init):
v = array_ops.identity(variable_scope.get_variable("bar", shape=shape))
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
index 6e894365af..90759c23ae 100644
--- a/tensorflow/python/kernel_tests/inplace_ops_test.py
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -153,7 +153,7 @@ class InplaceOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose(vy, vz)
def testError(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"must be a vector"):
_ = inplace_ops.inplace_update([[1.]], [[0]], [[10]]).eval()
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index 61944f7e31..afa24195cb 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -37,7 +37,7 @@ class IoOpsTest(test.TestCase):
with tempfile.NamedTemporaryFile(
prefix='ReadFileTest', dir=self.get_temp_dir(), delete=False) as temp:
temp.write(contents)
- with self.test_session():
+ with self.cached_session():
read = io_ops.read_file(temp.name)
self.assertEqual([], read.get_shape())
self.assertEqual(read.eval(), contents)
@@ -51,7 +51,7 @@ class IoOpsTest(test.TestCase):
prefix='WriteFileTest', dir=self.get_temp_dir(),
delete=False) as temp:
pass
- with self.test_session() as sess:
+ with self.cached_session() as sess:
w = io_ops.write_file(temp.name, contents)
sess.run(w)
with open(temp.name, 'rb') as f:
@@ -65,7 +65,7 @@ class IoOpsTest(test.TestCase):
contents = compat.as_bytes(contents)
subdir = os.path.join(self.get_temp_dir(), 'subdir1')
filepath = os.path.join(subdir, 'subdir2', 'filename')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
w = io_ops.write_file(filepath, contents)
sess.run(w)
with open(filepath, 'rb') as f:
@@ -88,7 +88,7 @@ class IoOpsTest(test.TestCase):
prefix=c, dir=self.get_temp_dir(), delete=True) for c in cases
]
- with self.test_session():
+ with self.cached_session():
# Test exact match without wildcards.
for f in files:
self.assertEqual(
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index f4ec3e3996..be2e31cb5a 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -25,6 +25,22 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_addition_test",
+ size = "small",
+ srcs = ["linear_operator_addition_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "linear_operator_block_diag_test",
size = "medium",
srcs = ["linear_operator_block_diag_test.py"],
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
index 6a72df6dfd..cf56168d63 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -19,10 +19,10 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_addition
from tensorflow.python.platform import test
linalg = linalg_lib
@@ -61,7 +61,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
op_a = linalg.LinearOperatorDiag([1., 1.])
op_sum = add_operators([op_a])
self.assertEqual(1, len(op_sum))
- self.assertTrue(op_sum[0] is op_a)
+ self.assertIs(op_sum[0], op_a)
def test_at_least_one_operators_required(self):
with self.assertRaisesRegexp(ValueError, "must contain at least one"):
@@ -76,11 +76,11 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
[1., 1.], is_positive_definite=True, name="A")
op_b = linalg.LinearOperatorDiag(
[2., 2.], is_positive_definite=True, name="B")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op_a, op_b])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
+ self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
# Adding positive definite operators produces positive def.
self.assertTrue(op.is_positive_definite)
@@ -98,7 +98,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
[2., 2.], is_positive_definite=True, name="op2")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_positive_definite=True, name="op3")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
@@ -121,11 +121,11 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
name="tril")
op3 = linalg.LinearOperatorDiag(
[3., 3.], is_non_singular=True, name="diag_b")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op1, op2, op3])
self.assertEqual(1, len(op_sum))
op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
+ self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
# The diag operators will be self-adjoint (because real and diagonal).
@@ -143,11 +143,11 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase):
op2 = linalg.LinearOperatorLowerTriangular(
[[2., 0.], [1.5, 2.]], name="tril")
op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
- with self.test_session():
+ with self.cached_session():
op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
self.assertEqual(1, len(op_sum))
op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
+ self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
self.assertEqual("my_operator", op.name)
@@ -185,7 +185,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
# _BadAdder) was never reached.
op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)
def test_tier_1_additions_done_by_tier_1(self):
diag1 = linalg.LinearOperatorDiag([1.])
@@ -200,7 +200,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
# _BadAdder) was never reached.
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
diag1 = linalg.LinearOperatorDiag([1.])
@@ -217,7 +217,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
# Tier 2 was never used (therefore, _BadAdder didn't raise).
op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
def test_cannot_add_everything_so_return_more_than_one_operator(self):
diag1 = linalg.LinearOperatorDiag([1.])
@@ -233,7 +233,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase):
self.assertEqual(2, len(op_sum))
found_diag = False
found_tril = False
- with self.test_session():
+ with self.cached_session():
for op in op_sum:
if isinstance(op, linalg.LinearOperatorDiag):
found_diag = True
@@ -271,9 +271,9 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -289,9 +289,9 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(3.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -308,9 +308,9 @@ class AddAndReturnScaledIdentityTest(test.TestCase):
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(1.2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -332,9 +332,9 @@ class AddAndReturnDiagTest(test.TestCase):
self.assertTrue(self._adder.can_add(id1, id2))
operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(2 *
linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
operator.to_dense().eval())
@@ -352,9 +352,9 @@ class AddAndReturnDiagTest(test.TestCase):
self.assertTrue(self._adder.can_add(op1, op2))
operator = self._adder.add(op1, op2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
operator.to_dense().eval())
@@ -377,9 +377,9 @@ class AddAndReturnTriLTest(test.TestCase):
self.assertTrue(self._adder.can_add(diag, diag))
self.assertTrue(self._adder.can_add(diag, tril))
operator = self._adder.add(diag, tril, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
+ self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
self.assertTrue(operator.is_positive_definite)
self.assertTrue(operator.is_non_singular)
@@ -399,9 +399,9 @@ class AddAndReturnMatrixTest(test.TestCase):
self.assertTrue(self._adder.can_add(diag1, diag2))
operator = self._adder.add(diag1, diag2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
+ self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
self.assertFalse(operator.is_positive_definite)
self.assertFalse(operator.is_non_singular)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 7261d4bb3b..f1e151ebd8 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -37,8 +37,10 @@ class LinearOperatorCirculantBaseTest(object):
"""Common class for circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -110,8 +112,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
@@ -121,7 +122,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -171,8 +172,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -182,7 +182,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -217,8 +217,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -228,7 +227,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -238,7 +237,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -250,7 +249,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
operator.assert_self_adjoint().run() # Should not fail
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = [1., 2., 1.]
spectrum = math_ops.fft(
math_ops.cast(convolution_kernel, dtypes.complex64))
@@ -266,7 +265,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
# Make spectrum the FFT of a real convolution kernel h. This ensures that
# spectrum is Hermitian.
h = linear_operator_test_util.random_normal(shape=(3, 4))
@@ -281,7 +280,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_convolution_kernel_same_as_first_row_of_to_dense(self):
spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
- with self.test_session():
+ with self.cached_session():
operator = linalg.LinearOperatorCirculant(spectrum)
h = operator.convolution_kernel()
c = operator.to_dense()
@@ -293,27 +292,27 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -331,8 +330,10 @@ class LinearOperatorCirculant2DBaseTest(object):
"""Common class for 2D circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -446,8 +447,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -482,8 +482,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -493,7 +492,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
operator = linalg.LinearOperatorCirculant(spectrum)
@@ -510,7 +509,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
self.assertAllClose(matrix, matrix_transpose, atol=0)
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(3, 3), dtype=dtypes.float32)
@@ -526,27 +525,27 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -574,13 +573,15 @@ class LinearOperatorCirculant3DTest(test.TestCase):
"""Simple test of the 3D case. See also the 1D and 2D tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
@@ -597,7 +598,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
self.assertAllClose(matrix, matrix_h)
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
# Convolution kernel is real ==> spectrum is Hermitian.
@@ -615,7 +616,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_defining_spd_operator_by_taking_real_part(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# S is real and positive.
s = linear_operator_test_util.random_uniform(
shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)
diff --git a/tensorflow/python/kernel_tests/linalg_grad_test.py b/tensorflow/python/kernel_tests/linalg_grad_test.py
index 0e4e58409e..e52f303fe0 100644
--- a/tensorflow/python/kernel_tests/linalg_grad_test.py
+++ b/tensorflow/python/kernel_tests/linalg_grad_test.py
@@ -40,7 +40,7 @@ def _AddTest(test, op_name, testcase_name, fn):
class ShapeTest(test_lib.TestCase):
def testBatchGradientUnknownSize(self):
- with self.test_session():
+ with self.cached_session():
batch_size = constant_op.constant(3)
matrix_size = constant_op.constant(4)
batch_identity = array_ops.tile(
@@ -120,7 +120,7 @@ def _GetMatrixBinaryFunctorGradientTest(functor_,
delta = epsilon**(1.0 / 3.0)
# tolerance obtained by looking at actual differences using
# np.linalg.norm(theoretical-numerical, np.inf) on -mavx build
- tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.04
+ tol = 1e-6 if dtype_ == np.float64 else float32_tol_fudge * 0.05
# The gradients for a and b may be of very different magnitudes,
# so to not get spurious failures we test them separately.
for factor, factor_init in [a, a_np], [b, b_np]:
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index 2f28d37eff..aa17f727d0 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -128,7 +128,7 @@ class AdjointTest(test.TestCase):
matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
6 + 6j]]).astype(dtype)
expected_transposed = np.conj(matrix_np.T)
- with self.test_session():
+ with self.cached_session():
matrix = ops.convert_to_tensor(matrix_np)
transposed = linalg.adjoint(matrix)
self.assertEqual((3, 2), transposed.get_shape())
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index ee86cf0b24..baeb40dd63 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -42,7 +42,7 @@ class ListDiffTest(test.TestCase):
out = [compat.as_bytes(str(a)) for a in out]
for diff_func in [array_ops.setdiff1d]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
diff --git a/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
new file mode 100644
index 0000000000..0e8197dccb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/logging_ops_logging_level_test.py
@@ -0,0 +1,70 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class PrintV2LoggingLevelTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogInfo(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.info)
+ self.evaluate(print_op)
+ self.assertTrue("I" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogWarning(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.warning)
+ self.evaluate(print_op)
+ self.assertTrue("W" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorLogError(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=tf_logging.error)
+ self.evaluate(print_op)
+ self.assertTrue("E" in printed.contents())
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue(expected in printed.contents())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index e635a71c78..4beddd00bb 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -18,20 +18,28 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import sys
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
class LoggingOpsTest(test.TestCase):
def testAssertDivideByZero(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
epsilon = ops.convert_to_tensor(1e-20)
x = ops.convert_to_tensor(0.0)
y = ops.convert_to_tensor(1.0)
@@ -57,6 +65,269 @@ class LoggingOpsTest(test.TestCase):
out.eval()
+class PrintV2Test(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensor(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorVarySummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=1)
+ self.evaluate(print_op)
+
+ expected = "[0 ... 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=2)
+ self.evaluate(print_op)
+
+ expected = "[0 1 ... 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=3)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, summarize=-1)
+ self.evaluate(print_op)
+
+ expected = "[0 1 2 3 4 5 6 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneVariable(self):
+ with self.cached_session():
+ var = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(var)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoVariablesInStructWithAssignAdd(self):
+ with self.cached_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ with self.captureWritesToStream(sys.stderr) as printed:
+ self.evaluate(plus_one)
+ print_op = logging_ops.print_v2(var_one, {"second": var_two})
+ self.evaluate(print_op)
+ expected = "3.14 {'second': [0 1 2 ... 7 8 9]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintTwoTensors(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor, tensor * 10)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9] [0 10 20 ... 70 80 90]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintPlaceholderGeneration(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2("{}6", {"{}": tensor * 10})
+ self.evaluate(print_op)
+ expected = "{}6 {'{}': [0 10 20 ... 70 80 90]}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintNoTensors(self):
+ with self.cached_session():
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(23, [23, 5], {"6": 12})
+ self.evaluate(print_op)
+ expected = "23 [23, 5] {'6': 12}"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintFloatScalar(self):
+ with self.cached_session():
+ tensor = ops.convert_to_tensor(434.43)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "434.43"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintStringScalar(self):
+ with self.cached_session():
+ tensor = ops.convert_to_tensor("scalar")
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(tensor)
+ self.evaluate(print_op)
+ expected = "scalar"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintComplexTensorStruct(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ small_tensor = constant_op.constant([0.3, 12.4, -16.1])
+ big_tensor = math_ops.mul(tensor, 10)
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(
+ "first:", tensor, "middle:",
+ {"small": small_tensor, "Big": big_tensor}, 10,
+ [tensor * 2, tensor])
+ self.evaluate(print_op)
+ # Note that the keys in the dict will always be sorted,
+ # so 'Big' comes before 'small'
+ expected = ("first: [0 1 2 ... 7 8 9] "
+ "middle: {'Big': [0 10 20 ... 70 80 90], "
+ "'small': [0.3 12.4 -16.1]} "
+ "10 [[0 2 4 ... 14 16 18], [0 1 2 ... 7 8 9]]")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensor(self):
+ with self.cached_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2(sparse)
+ self.evaluate(print_op)
+ expected = ("'SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])'")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintSparseTensorInDataStruct(self):
+ with self.cached_session():
+ ind = [[0, 0], [1, 0], [1, 3], [4, 1], [1, 4], [3, 2], [3, 3]]
+ val = [0, 10, 13, 4, 14, 32, 33]
+ shape = [5, 6]
+
+ sparse = sparse_tensor.SparseTensor(
+ constant_op.constant(ind, dtypes.int64),
+ constant_op.constant(val, dtypes.int64),
+ constant_op.constant(shape, dtypes.int64))
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ print_op = logging_ops.print_v2([sparse])
+ self.evaluate(print_op)
+ expected = ("['SparseTensor(indices=[[0 0]\n"
+ " [1 0]\n"
+ " [1 3]\n"
+ " ...\n"
+ " [1 4]\n"
+ " [3 2]\n"
+ " [3 3]], values=[0 10 13 ... 14 32 33], shape=[5 6])']")
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintOneTensorStdout(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.captureWritesToStream(sys.stdout) as printed:
+ print_op = logging_ops.print_v2(
+ tensor, output_stream=sys.stdout)
+ self.evaluate(print_op)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testInvalidOutputStreamRaisesError(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ with self.assertRaises(ValueError):
+ print_op = logging_ops.print_v2(
+ tensor, output_stream="unknown")
+ self.evaluate(print_op)
+
+ def testPrintOpName(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ print_op = logging_ops.print_v2(tensor, name="print_name")
+ self.assertEqual(print_op.name, "print_name")
+
+ def testNoDuplicateFormatOpGraphModeAfterExplicitFormat(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ formatted_string = string_ops.string_format("{}", tensor)
+ print_op = logging_ops.print_v2(formatted_string)
+ self.evaluate(print_op)
+ graph_ops = ops.get_default_graph().get_operations()
+ format_ops = [op for op in graph_ops if op.type == "StringFormat"]
+ # Should be only 1 format_op for graph mode.
+ self.assertEqual(len(format_ops), 1)
+
+ def testPrintOneTensorEagerOnOpCreate(self):
+ with self.cached_session():
+ with context.eager_mode():
+ tensor = math_ops.range(10)
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed:
+ logging_ops.print_v2(tensor)
+ self.assertTrue((expected + "\n") in printed.contents())
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testPrintInDefunWithoutExplicitEvalOfPrint(self):
+ @function.defun
+ def f():
+ tensor = math_ops.range(10)
+ logging_ops.print_v2(tensor)
+ return tensor
+
+ expected = "[0 1 2 ... 7 8 9]"
+ with self.captureWritesToStream(sys.stderr) as printed_one:
+ x = f()
+ self.evaluate(x)
+ self.assertTrue((expected + "\n") in printed_one.contents())
+
+ # We execute the function again to make sure it doesn't only print on the
+ # first call.
+ with self.captureWritesToStream(sys.stderr) as printed_two:
+ y = f()
+ self.evaluate(y)
+ self.assertTrue((expected + "\n") in printed_two.contents())
+
+
class PrintGradientTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -65,8 +336,13 @@ class PrintGradientTest(test.TestCase):
inp_printed = logging_ops.Print(inp, [inp])
self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+ def testPrintString(self):
+ inp = constant_op.constant(2.0, shape=[100, 32])
+ inp_printed = logging_ops.Print(inp, ["hello"])
+ self.assertEqual(inp.get_shape(), inp_printed.get_shape())
+
def testPrintGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(2.0, shape=[100, 32], name="in")
w = constant_op.constant(4.0, shape=[10, 100], name="w")
wx = math_ops.matmul(w, inp, name="wx")
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 5f08339fe5..6791a03e2e 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -21,6 +21,7 @@ import os
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -29,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -36,7 +38,7 @@ from tensorflow.python.training import server_lib
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -53,8 +55,14 @@ class HashTableOpTest(test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
+ exported_keys_tensor, exported_values_tensor = table.export()
+
+ self.assertItemsEqual([b"brain", b"salad", b"surgery"],
+ exported_keys_tensor.eval())
+ self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
+
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -72,7 +80,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -90,7 +98,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -107,7 +115,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -135,7 +143,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +158,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -173,7 +181,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -181,6 +189,11 @@ class HashTableOpTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
table.init.run()
+ # Ref types do not produce a lookup signature mismatch.
+ input_string_ref = variables.Variable("brain")
+ variables.global_variables_initializer().run()
+ self.assertEqual(0, table.lookup(input_string_ref).eval())
+
input_string = constant_op.constant([1, 2, 3], dtypes.int64)
with self.assertRaises(TypeError):
table.lookup(input_string)
@@ -190,7 +203,7 @@ class HashTableOpTest(test.TestCase):
lookup_ops.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup_ops.HashTable(
@@ -198,7 +211,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup_ops.HashTable(
lookup_ops.KeyValueTensorInitializer(
@@ -211,7 +224,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -223,7 +236,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -261,6 +274,21 @@ class HashTableOpTest(test.TestCase):
table.init.run()
self.assertAllEqual(3, table.size().eval())
+ def testHashTableInt32String(self):
+ with self.cached_session():
+ default_val = "n/a"
+ keys = constant_op.constant([0, 1, 2], dtypes.int32)
+ values = constant_op.constant(["brain", "salad", "surgery"])
+ table = lookup_ops.HashTable(
+ lookup_ops.KeyValueTensorInitializer(keys, values), default_val)
+ table.init.run()
+
+ input_tensor = constant_op.constant([0, 1, -1])
+ output = table.lookup(input_tensor)
+
+ result = output.eval()
+ self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
+
class IndexTableFromFile(test.TestCase):
@@ -272,7 +300,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -284,7 +312,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -299,7 +327,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -314,7 +342,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -328,13 +356,14 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
+
feed_dict = {vocabulary_placeholder.name: vocabulary_file}
lookup_ops.tables_initializer().run(feed_dict=feed_dict)
self.assertAllEqual((1, 2, 3), ids.eval())
@@ -344,7 +373,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -359,7 +388,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
num_oov_buckets=1,
@@ -374,7 +403,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -385,7 +414,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -432,7 +461,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -444,7 +473,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -459,7 +488,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -471,7 +500,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -490,14 +519,14 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_table_ref_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
self.assertIsNotNone(table.table_ref)
def test_index_table_from_file_table_ref_without_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=0)
self.assertIsNotNone(table.table_ref)
@@ -506,21 +535,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int64, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
dtypes.int32, dtypes.int64)
table = lookup_ops.HashTable(init, default_value=-1)
@@ -531,18 +560,25 @@ class KeyValueTensorInitializerTest(test.TestCase):
class IndexTableFromTensor(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes
def test_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ table = lookup_ops.index_table_from_tensor(
+ vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
+
+ if not context.executing_eagerly():
+ with self.assertRaises(errors_impl.OpError):
+ self.evaluate(
+ table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))))
+ else:
+ # Reinitializing a table in eager should work.
table = lookup_ops.index_table_from_tensor(
vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
- ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
-
- self.assertRaises(errors_impl.OpError, ids.eval)
- lookup_ops.tables_initializer().run()
- self.assertAllEqual((1, 2, 3), ids.eval())
+ self.evaluate(lookup_ops.tables_initializer())
+ ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
+ self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -553,7 +589,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -565,7 +601,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
default_value=default_value)
@@ -576,14 +612,14 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
"vocabulary_list must be specified"):
lookup_ops.index_table_from_tensor(
vocabulary_list=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_vocabulary_list(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_table_from_tensor(
vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -593,7 +629,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup_ops.index_table_from_tensor(
vocabulary_list=["brain", "salad", "surgery"],
@@ -623,7 +659,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
type_funcs = [str, constant_op.constant]
for type_func in type_funcs:
vocabulary_file = type_func(vocabulary_path)
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(
@@ -636,7 +672,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -650,7 +686,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
@@ -665,7 +701,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -677,7 +713,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -690,7 +726,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -702,7 +738,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -715,7 +751,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -729,7 +765,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["hello", "hello"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list)
@@ -740,7 +776,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
table = lookup_ops.index_to_string_table_from_tensor(
vocabulary_list=vocabulary_list, default_value=default_value)
@@ -761,28 +797,26 @@ class InitializeTableFromFileOpTest(test.TestCase):
f.write("\n".join(values) + "\n")
return vocabulary_file
+ @test_util.run_in_graph_and_eager_modes
def testInitializeStringTable(self):
vocabulary_file = self._createVocabFile("one_column_1.txt")
+ default_value = -1
+ table = lookup_ops.HashTable(
+ lookup_ops.TextFileInitializer(
+ vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
+ dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), default_value)
+ self.evaluate(table.init)
- with self.test_session():
- default_value = -1
- table = lookup_ops.HashTable(
- lookup_ops.TextFileInitializer(
- vocabulary_file, dtypes.string,
- lookup_ops.TextFileIndex.WHOLE_LINE, dtypes.int64,
- lookup_ops.TextFileIndex.LINE_NUMBER), default_value)
- table.init.run()
-
- output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
+ output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
- result = output.eval()
- self.assertAllEqual([0, 1, -1], result)
+ result = self.evaluate(output)
+ self.assertAllEqual([0, 1, -1], result)
def testInitializeInt64Table(self):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -800,7 +834,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.LINE_NUMBER
value_index = lookup_ops.TextFileIndex.WHOLE_LINE
@@ -821,7 +855,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -843,7 +877,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -857,7 +891,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup_ops.TextFileIndex.WHOLE_LINE
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -870,7 +904,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup_ops.TextFileIndex.LINE_NUMBER
@@ -885,7 +919,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup_ops.HashTable(
@@ -924,7 +958,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup_ops.HashTable(
@@ -934,7 +968,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -982,7 +1016,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup_ops.HashTable(
lookup_ops.TextFileInitializer(
@@ -1008,7 +1042,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -1031,7 +1065,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1048,7 +1082,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1065,7 +1099,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup_ops.HashTable(
@@ -1090,7 +1124,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1110,7 +1144,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1132,7 +1166,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1151,7 +1185,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1172,7 +1206,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -1194,20 +1228,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup_ops.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -1248,7 +1282,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1269,7 +1303,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1292,7 +1326,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -1328,7 +1362,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -1355,7 +1389,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -1383,7 +1417,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -1410,7 +1444,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -1451,7 +1485,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
def testIdTableWithHashBucketsNoInnerTable(self):
- with self.test_session():
+ with self.cached_session():
table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
self.assertIsNone(table.table_ref)
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 87fc715783..3ce0b74263 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -61,62 +61,62 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = losses.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = losses.absolute_difference(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = losses.absolute_difference(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant((1.2, 0.0), shape=(2, 1))
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = losses.absolute_difference(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@test_util.assert_no_new_pyobjects_executing_eagerly
@@ -134,12 +134,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.softmax_cross_entropy(labels, logits, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
@@ -152,7 +152,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -162,7 +162,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -171,7 +171,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -181,7 +181,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant((1.2, 3.4, 5.6))
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -190,7 +190,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -199,12 +199,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = losses.softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -215,7 +215,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
losses.softmax_cross_entropy(labels, logits, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
@@ -242,12 +242,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.sparse_softmax_cross_entropy(labels, logits, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int32)
@@ -263,7 +263,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
losses.sparse_softmax_cross_entropy(labels, logits)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]], dtype=dtypes.int64)
@@ -272,7 +272,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([0, 1, 2])
@@ -285,7 +285,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -315,7 +315,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -324,7 +324,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -334,7 +334,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(
labels, logits, constant_op.constant((weights,)))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -345,7 +345,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = array_ops.placeholder(dtypes.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={weights: ((1.2,), (3.4,), (5.6,))})
@@ -355,7 +355,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = array_ops.placeholder(dtypes.float32)
labels = array_ops.placeholder(dtypes.int32)
weights = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={
@@ -370,7 +370,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
logits = array_ops.placeholder(dtypes.float32, shape=(None, 3))
labels = array_ops.placeholder(dtypes.int32, shape=(None, 1))
weights = array_ops.placeholder(dtypes.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
loss_val = sess.run(loss,
feed_dict={
@@ -387,7 +387,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -396,7 +396,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -405,7 +405,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -414,12 +414,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=(3, 1))
- with self.test_session():
+ with self.cached_session():
loss = losses.sparse_softmax_cross_entropy(labels, logits, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -432,7 +432,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -445,7 +445,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -458,7 +458,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -472,7 +472,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -488,7 +488,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -506,7 +506,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits, weights)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits, weights)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -531,7 +531,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -542,7 +542,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -562,7 +562,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertEquals(logits.dtype, loss.dtype)
self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testSigmoidFloat64(self):
@@ -577,7 +577,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = losses.sigmoid_cross_entropy(labels, logits)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(44.444, loss.eval(), 3)
def testSigmoidNoReduction(self):
@@ -590,7 +590,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
labels, logits, reduction=losses.Reduction.NONE)
self.assertEquals(logits.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose((
(0., 0., 0.),
(0., 100., 100.),
@@ -598,7 +598,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
), loss.eval(), 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -621,7 +621,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -656,33 +656,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = losses.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = losses.log_loss(self._labels, tf_predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = losses.log_loss(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -690,7 +690,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -700,7 +700,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, tf_predictions,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -710,7 +710,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = losses.log_loss(self._labels, tf_predictions,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -721,7 +721,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -730,7 +730,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -739,12 +739,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = losses.log_loss(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.log_loss(self._labels, self._predictions, weights)
@@ -757,7 +757,7 @@ class LogLossTest(test.TestCase):
self._predictions,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -771,7 +771,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -784,7 +784,7 @@ class LogLossTest(test.TestCase):
self._predictions,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -795,35 +795,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = losses.log_loss(self._labels, tf_predictions, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = losses.log_loss(self._labels, self._predictions, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = losses.hinge_loss(labels, logits).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = losses.hinge_loss(labels, logits)
self.assertAllClose(loss.eval(), 0.0, atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = losses.hinge_loss(labels, logits)
@@ -832,7 +832,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), 0.175, atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = losses.hinge_loss(labels, logits)
@@ -844,14 +844,14 @@ class HingeLossTest(test.TestCase):
class HuberLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = losses.huber_loss(labels, predictions).eval()
def testAllQuadratic(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
loss = losses.huber_loss(labels, predictions)
@@ -859,7 +859,7 @@ class HuberLossTest(test.TestCase):
0.5 * (0.25 + 0.16 + 1.0 + 0.25) / 4., atol=1e-5)
def testAllLinear(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([1.5, -1.4, -1.0, 0.0])
labels = constant_op.constant([0.0, 1.0, 0.0, 1.5])
loss = losses.huber_loss(labels, predictions)
@@ -867,7 +867,7 @@ class HuberLossTest(test.TestCase):
(1.5 + 2.4 + 1.0 + 1.5) / 4. - 0.5, atol=1e-5)
def testMixedQuadraticLinear(self):
- with self.test_session():
+ with self.cached_session():
predictions = constant_op.constant([[1.5, -1.4, -1.0, 0.0],
[1.5, -1.4, -1.0, 0.0]])
labels = constant_op.constant([[1.0, -1.0, 0.0, 0.5],
@@ -879,7 +879,7 @@ class HuberLossTest(test.TestCase):
self.assertAllClose(loss.eval(), expected_loss, atol=1e-5)
def testAllQuadraticDelta(self):
- with self.test_session():
+ with self.cached_session():
delta = 0.5
predictions = constant_op.constant([1.5, -1.4, -0.5, 0.0])
labels = constant_op.constant([1.0, -1.0, 0.0, 0.5])
@@ -894,7 +894,7 @@ class HuberLossTest(test.TestCase):
expected = delta * np.array([1.5, 2.4, 1.0, 1.5]).mean()
expected -= 0.5 * delta**2
loss = losses.huber_loss(labels, predictions, delta=delta)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected, loss.eval(), atol=1e-5)
@@ -906,13 +906,13 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testScalar(self):
- with self.test_session():
+ with self.cached_session():
self.assertEqual(
0.0,
losses.mean_squared_error(predictions=constant_op.constant(0),
@@ -920,55 +920,55 @@ class MeanSquaredErrorTest(test.TestCase):
def testAllCorrectNoLossWeight(self):
loss = losses.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = losses.mean_squared_error(self._labels, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = losses.mean_squared_error(self._labels, self._predictions,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = losses.mean_squared_error(self._labels, self._predictions, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -994,7 +994,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 3.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -1003,7 +1003,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
def _test_valid_weights(
self, labels, predictions, expected_loss, weights=1.0):
- with self.test_session():
+ with self.cached_session():
static_inputs_op = losses.mean_pairwise_squared_error(
predictions=predictions, labels=labels, weights=weights)
self.assertAlmostEqual(expected_loss, static_inputs_op.eval(), places=3)
@@ -1054,7 +1054,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -1073,7 +1073,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -1122,7 +1122,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
predictions=predictions_placeholder,
labels=labels_placeholder,
weights=weights_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
dynamic_inputs_op.eval(feed_dict={
predictions_placeholder: predictions,
@@ -1191,7 +1191,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
labels=array_ops.concat([labels0, labels1], 0),
predictions=array_ops.concat([predictions0, predictions1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1216,7 +1216,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 0, 1], [0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
losses.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1229,7 +1229,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1237,7 +1237,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1255,7 +1255,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = losses.cosine_distance(tf_labels, tf_preds, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1264,7 +1264,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=np.asarray((1, 0, 0)).reshape((3, 1, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1274,7 +1274,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
@@ -1286,7 +1286,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2, 1)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1296,7 +1296,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 1, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1305,7 +1305,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2, 1)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1411,7 +1411,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
self._raw_losses, weights=weight)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
np.mean(weight * self._raw_losses), weighted_loss.eval())
@@ -1429,7 +1429,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
self._raw_losses, weights=weights_placeholder)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
weighted_loss.eval(feed_dict={weights_placeholder: weights})
@@ -1452,7 +1452,7 @@ class ComputeWeightedLossTest(test.TestCase):
weighted_loss = losses.compute_weighted_loss(
raw_losses, weights=weights_placeholder)
self.assertEqual(1, len(util.get_losses()))
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
weighted_loss.eval(feed_dict={weights_placeholder: weights})
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
index dc3ea38671..f71857a3cb 100644
--- a/tensorflow/python/kernel_tests/manip_ops_test.py
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -42,12 +42,12 @@ class RollTest(test_util.TensorFlowTestCase):
def _testRoll(self, np_input, shift, axis):
expected_roll = np.roll(np_input, shift, axis)
- with self.test_session():
+ with self.cached_session():
roll = manip_ops.roll(np_input, shift, axis)
self.assertAllEqual(roll.eval(), expected_roll)
def _testGradient(self, np_input, shift, axis):
- with self.test_session():
+ with self.cached_session():
inx = constant_op.constant(np_input.tolist())
xs = list(np_input.shape)
y = manip_ops.roll(inx, shift, axis)
@@ -94,7 +94,7 @@ class RollTest(test_util.TensorFlowTestCase):
self._testAll(np.random.randint(-100, 100, (5)).astype(np.int32), 3, -1)
self._testAll(np.random.randint(-100, 100, (4, 4)).astype(np.int32), 3, -2)
# Make sure negative axis should be 0 <= axis + dims < dims
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(np.random.randint(-100, 100, (4, 4)).astype(np.int32),
@@ -111,7 +111,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = array_ops.placeholder(dtype=dtypes.int32)
shift = 1
axis = 0
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"input must be 1-D or higher"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={tensor: 7})
@@ -127,7 +127,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = 1
axis = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"axis must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={axis: [[0, 1]]})
@@ -143,7 +143,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift must be a scalar or a 1-D vector"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [[0, 1]]})
@@ -158,7 +158,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [[1, 2], [3, 4]]
shift = array_ops.placeholder(dtype=dtypes.int32)
axis = [0, 1]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"shift and axis must have the same size"):
manip_ops.roll(tensor, shift, axis).eval(feed_dict={shift: [1]})
@@ -167,7 +167,7 @@ class RollTest(test_util.TensorFlowTestCase):
tensor = [1, 2]
shift = 1
axis = 1
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"is out of range"):
manip_ops.roll(tensor, shift, axis).eval()
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index b167278984..309da8f184 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -206,7 +206,7 @@ class MatMulInfixOperatorTest(test_lib.TestCase):
b = ops.convert_to_tensor([[40.0, 50.0], [60.0, 70.0], [80.0, 90.0]])
c = infix_matmul(a, b)
d = math_ops.matmul(a, b)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(c.eval(), d.eval())
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index f41967ff98..720ba806e9 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -114,7 +114,7 @@ class InverseOpTest(test.TestCase):
def testNotInvertible(self):
# The input should be invertible.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Input is not invertible."):
# All rows of the matrix below add to zero.
tensor3 = constant_op.constant([[1., 0., -1.], [-1., 1., 0.],
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 33288392c0..dd01ba11af 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -143,7 +143,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, matrix)
with self.assertRaises(ValueError):
@@ -154,7 +154,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# right-hand sides.
matrix = np.array([[1., 0.], [0., 1.]])
rhs = np.array([[1., 0.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
self._verifySolve(matrix, rhs)
with self.assertRaises(ValueError):
@@ -164,7 +164,7 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# The input should be invertible.
# The matrix is singular because it has a zero on the diagonal.
singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Input matrix is not invertible."):
self._verifySolve(singular_matrix, singular_matrix)
with self.assertRaisesOpError("Input matrix is not invertible."):
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index 55653489af..5dcdb9e420 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -192,7 +192,7 @@ class MeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -209,7 +209,7 @@ class MeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -253,7 +253,7 @@ class MeanTest(test.TestCase):
metrics.mean(values, weights=np.ones((3, 2, 4, 1))),
metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),)
expected = np.mean(values)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
for mean_result in mean_results:
mean, update_op = mean_result
@@ -266,7 +266,7 @@ class MeanTest(test.TestCase):
np.sum(np.multiply(weights, np.ones_like(values)))
)
mean, update_op = metrics.mean(values, weights=weights)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
self.assertAlmostEqual(expected, update_op.eval(), places=5)
self.assertAlmostEqual(expected, mean.eval(), places=5)
@@ -330,7 +330,7 @@ class MeanTest(test.TestCase):
# Dynamic shapes.
with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
- with self.test_session():
+ with self.cached_session():
_, update_op = metrics.mean(values_placeholder, invalid_weight)
variables.local_variables_initializer().run()
update_op.eval(feed_dict={values_placeholder: values})
@@ -359,7 +359,7 @@ class MeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -376,7 +376,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -397,7 +397,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -418,7 +418,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testBinaryWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -445,7 +445,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -472,7 +472,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -499,7 +499,7 @@ class MeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -575,7 +575,7 @@ class AccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
accuracy, update_op = metrics.accuracy(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -588,7 +588,7 @@ class AccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -618,7 +618,7 @@ class AccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions)
sess.run(variables.local_variables_initializer())
@@ -628,7 +628,7 @@ class AccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizesWithScalarWeight(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)
sess.run(variables.local_variables_initializer())
@@ -642,7 +642,7 @@ class AccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions, weights)
sess.run(variables.local_variables_initializer())
@@ -662,7 +662,7 @@ class AccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.accuracy(labels, predictions,
weights_placeholder)
@@ -674,7 +674,7 @@ class AccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -746,7 +746,7 @@ class PrecisionTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -765,7 +765,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -778,7 +778,7 @@ class PrecisionTest(test.TestCase):
constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -789,7 +789,7 @@ class PrecisionTest(test.TestCase):
precision, update_op = metrics.precision(
labels, predictions, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -806,7 +806,7 @@ class PrecisionTest(test.TestCase):
}
precision, update_op = metrics.precision(labels, predictions, weights=2)
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 2.0
weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
@@ -826,7 +826,7 @@ class PrecisionTest(test.TestCase):
precision, update_op = metrics.precision(
labels, predictions, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -844,7 +844,7 @@ class PrecisionTest(test.TestCase):
predictions,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -864,7 +864,7 @@ class PrecisionTest(test.TestCase):
predictions,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -881,7 +881,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -891,7 +891,7 @@ class PrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.precision(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -933,7 +933,7 @@ class RecallTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -952,7 +952,7 @@ class RecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -965,7 +965,7 @@ class RecallTest(test.TestCase):
constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -976,7 +976,7 @@ class RecallTest(test.TestCase):
weights = constant_op.constant([[2], [5]])
recall, update_op = metrics.recall(labels, predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -990,7 +990,7 @@ class RecallTest(test.TestCase):
weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
recall, update_op = metrics.recall(labels, predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1005,7 +1005,7 @@ class RecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1015,7 +1015,7 @@ class RecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.recall(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1055,7 +1055,7 @@ class AUCTest(test.TestCase):
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
auc, update_op = metrics.auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1073,7 +1073,7 @@ class AUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.auc(labels, predictions, curve=curve)
@@ -1084,7 +1084,7 @@ class AUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect_multipleLabelDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for label_dtype in (
dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
predictions = constant_op.constant(
@@ -1099,7 +1099,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1112,7 +1112,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1127,7 +1127,7 @@ class AUCTest(test.TestCase):
# Regarding the AUC-PR tests: note that the preferred method when
# calculating AUC-PR is summation_method='careful_interpolation'.
def testCorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1141,7 +1141,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testCorrectAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1157,7 +1157,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testThirdCorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1173,7 +1173,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(expected, auc.eval(), delta=1e-3)
def testIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1186,7 +1186,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1201,7 +1201,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdIncorrectAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1218,7 +1218,7 @@ class AUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.auc(labels, predictions)
@@ -1229,7 +1229,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.auc(labels, predictions)
@@ -1240,7 +1240,7 @@ class AUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.auc(labels, predictions, curve='PR')
@@ -1301,7 +1301,7 @@ class AUCTest(test.TestCase):
scale=1.0, size=num_samples)):
expected_auc = self.np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1370,7 +1370,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1390,7 +1390,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -1405,7 +1405,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -1420,7 +1420,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1439,7 +1439,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -1457,7 +1457,7 @@ class SpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.specificity_at_sensitivity(
labels, predictions, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -1507,7 +1507,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1527,7 +1527,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -1542,7 +1542,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -1557,7 +1557,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -1576,7 +1576,7 @@ class SensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.sensitivity_at_specificity(
labels, predictions, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -1638,7 +1638,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
thresholds)
rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates, then verify idempotency.
@@ -1654,7 +1654,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -1670,7 +1670,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect_multipleLabelDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for label_dtype in (
dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
predictions = constant_op.constant(
@@ -1692,7 +1692,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -1708,7 +1708,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1738,7 +1738,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -1768,7 +1768,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -1792,7 +1792,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -1842,7 +1842,7 @@ class PrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -2801,7 +2801,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_absolute_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2822,7 +2822,7 @@ class MeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.mean_absolute_error(labels, predictions, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -2866,7 +2866,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(labels, predictions,
normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2891,7 +2891,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(
labels, predictions, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -2907,7 +2907,7 @@ class MeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.mean_relative_error(
labels, predictions, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -2945,7 +2945,7 @@ class MeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2963,7 +2963,7 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -2976,7 +2976,7 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -2990,13 +2990,13 @@ class MeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.mean_squared_error(labels, predictions, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3020,7 +3020,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3063,7 +3063,7 @@ class MeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -3122,7 +3122,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.root_mean_squared_error(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3135,7 +3135,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -3148,7 +3148,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -3161,7 +3161,7 @@ class RootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -3220,7 +3220,7 @@ class MeanCosineDistanceTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3242,7 +3242,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -3258,7 +3258,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -3279,7 +3279,7 @@ class MeanCosineDistanceTest(test.TestCase):
np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -3298,7 +3298,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(
labels, predictions, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -3317,7 +3317,7 @@ class MeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.mean_cosine_distance(
labels, predictions, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -3352,7 +3352,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -3369,7 +3369,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -3445,7 +3445,7 @@ class MeanIOUTest(test.TestCase):
mean_iou, update_op = metrics.mean_iou(
labels, predictions, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3459,7 +3459,7 @@ class MeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3490,7 +3490,7 @@ class MeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3538,7 +3538,7 @@ class MeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -3585,7 +3585,7 @@ class MeanIOUTest(test.TestCase):
],
0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
confusion_matrix = update_op.eval()
@@ -3597,7 +3597,7 @@ class MeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertEqual(40, update_op.eval()[0])
@@ -3607,7 +3607,7 @@ class MeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
@@ -3637,7 +3637,7 @@ class MeanIOUTest(test.TestCase):
0, shape=[1])
],
0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -3657,7 +3657,7 @@ class MeanIOUTest(test.TestCase):
[[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
@@ -3669,7 +3669,7 @@ class MeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
@@ -3687,7 +3687,7 @@ class MeanIOUTest(test.TestCase):
[[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
@@ -3751,7 +3751,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3764,7 +3764,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
self.assertEqual(initial_mean_accuracy, mean_accuracy.eval())
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3796,7 +3796,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -3844,7 +3844,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -3880,7 +3880,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
@@ -3891,7 +3891,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes)
sess.run(variables.local_variables_initializer())
@@ -3910,7 +3910,7 @@ class MeanPerClassAccuracyTest(test.TestCase):
constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
mean_accuracy, update_op = metrics.mean_per_class_accuracy(
labels, predictions, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -3944,7 +3944,7 @@ class FalseNegativesTest(test.TestCase):
tn, tn_update_op = metrics.false_negatives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(3., tn_update_op.eval())
@@ -3963,7 +3963,7 @@ class FalseNegativesTest(test.TestCase):
tn, tn_update_op = metrics.false_negatives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(5., tn_update_op.eval())
@@ -3993,7 +3993,7 @@ class FalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.false_negatives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -4012,7 +4012,7 @@ class FalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -4043,7 +4043,7 @@ class FalsePositivesTest(test.TestCase):
tn, tn_update_op = metrics.false_positives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(7., tn_update_op.eval())
@@ -4062,7 +4062,7 @@ class FalsePositivesTest(test.TestCase):
tn, tn_update_op = metrics.false_positives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(14., tn_update_op.eval())
@@ -4092,7 +4092,7 @@ class FalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.false_positives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -4113,7 +4113,7 @@ class FalsePositivesAtThresholdsTest(test.TestCase):
(19.0, 23.0, 29.0, 31.0)),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -4144,7 +4144,7 @@ class TrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(3., tn_update_op.eval())
@@ -4163,7 +4163,7 @@ class TrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(4., tn_update_op.eval())
@@ -4193,7 +4193,7 @@ class TrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.true_negatives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -4212,7 +4212,7 @@ class TrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -4243,7 +4243,7 @@ class TruePositivesTest(test.TestCase):
tn, tn_update_op = metrics.true_positives(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(7., tn_update_op.eval())
@@ -4262,7 +4262,7 @@ class TruePositivesTest(test.TestCase):
tn, tn_update_op = metrics.true_positives(
labels=labels, predictions=predictions, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllClose(0., tn.eval())
self.assertAllClose(12., tn_update_op.eval())
@@ -4292,7 +4292,7 @@ class TruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.true_positives_at_thresholds(
predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -4309,7 +4309,7 @@ class TruePositivesAtThresholdsTest(test.TestCase):
predictions=predictions, labels=labels, weights=37.0,
thresholds=[0.15, 0.5, 0.85])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index 89ada8430e..6cc70f7c89 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -66,7 +66,7 @@ class VerifyTensorAllFiniteTest(test.TestCase):
class NumericsTest(test.TestCase):
def testInf(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant(1.0)
t2 = constant_op.constant(0.0)
a = math_ops.div(t1, t2)
@@ -76,7 +76,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testNaN(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant(0.0)
t2 = constant_op.constant(0.0)
a = math_ops.div(t1, t2)
@@ -86,7 +86,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testBoth(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant([1.0, 0.0])
t2 = constant_op.constant([0.0, 0.0])
a = math_ops.div(t1, t2)
@@ -96,7 +96,7 @@ class NumericsTest(test.TestCase):
a.eval()
def testPassThrough(self):
- with self.test_session(graph=ops.Graph()):
+ with self.session(graph=ops.Graph()):
t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
checked = array_ops.check_numerics(t1, message="pass through test")
value = checked.eval()
diff --git a/tensorflow/python/kernel_tests/pad_op_test.py b/tensorflow/python/kernel_tests/pad_op_test.py
index 944de217a1..e415d7879e 100644
--- a/tensorflow/python/kernel_tests/pad_op_test.py
+++ b/tensorflow/python/kernel_tests/pad_op_test.py
@@ -188,7 +188,7 @@ class PadOpTest(test.TestCase):
mode="SYMMETRIC").eval()
def testInvalid(self):
- with self.test_session():
+ with self.cached_session():
x = [[1, 2, 3], [4, 5, 6]]
with self.assertRaisesRegexp(ValueError, "Unknown padding mode"):
array_ops.pad(x, [[1, 0], [2, 1]], mode="weird").eval()
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index d8c3f9823c..95f3dcceea 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -95,13 +95,13 @@ class PaddingFIFOQueueTest(test.TestCase):
""", q.queue_ref.op.node_def)
def testEnqueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()
def testEnqueueWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((3, 2),))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
@@ -111,14 +111,14 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(1, q.size().eval())
def testEnqueueManyWithShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, [dtypes_lib.int32, dtypes_lib.int32], shapes=[(), (2,)])
q.enqueue_many([[1, 2, 3, 4], [[1, 1], [2, 2], [3, 3], [4, 4]]]).run()
self.assertEqual(4, q.size().eval())
def testParallelEnqueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -144,7 +144,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testParallelDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -168,7 +168,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, results)
def testDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -182,7 +182,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elems[i]], vals)
def testEnqueueAndBlockingDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(3, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
@@ -212,7 +212,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elem], result)
def testMultiEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10,
(dtypes_lib.int32, dtypes_lib.float32),
((), ()))
@@ -230,12 +230,12 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([y], y_val)
def testQueueSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
self.assertEqual([0], q.size().eval())
def testQueueSizeAfterEnqueueAndDequeue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
@@ -248,7 +248,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, size.eval())
def testEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -261,7 +261,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([elems[i % 4]], vals)
def testEmptyEnqueueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, (
(None, None),))
empty_t = constant_op.constant(
@@ -274,7 +274,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([0], size_t.eval())
def testEmptyDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, shapes=((),))
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue_many(0)
@@ -284,7 +284,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueManyWithDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((None,),))
enqueue_op = q.enqueue(([10.0],))
@@ -295,7 +295,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testEmptyDequeueUpToWithDynamicShape(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, shapes=((None,),))
enqueue_op = q.enqueue(([10.0],))
@@ -306,7 +306,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual([], dequeued_t.eval().tolist())
def testConstructPaddingFIFOQueueWithNoShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
r"When providing partial shapes, a list of shapes must be provided."):
@@ -314,7 +314,7 @@ class PaddingFIFOQueueTest(test.TestCase):
None).queue_ref.eval()
def testMultiEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10,
(dtypes_lib.float32, dtypes_lib.int32),
((), (2,)))
@@ -332,7 +332,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testMultiEnqueueManyWithPartiallyKnownShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
float_elems = [10.0, 20.0, 30.0, 40.0]
@@ -349,7 +349,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(int_elems[i % 4], int_val)
def testDequeueMany(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -361,7 +361,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testDequeueUpToNoBlocking(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_op = q.enqueue_many((elems,))
@@ -373,7 +373,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems[4:8], dequeued_t.eval())
def testMultiDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (2,)))
float_elems = [
@@ -404,7 +404,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(int_val.shape, dequeued_single_t[1].get_shape())
def testMultiDequeueManyWithPartiallyKnownShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.float32, dtypes_lib.int32), shapes=((), (None,)))
float_elems = [
@@ -443,7 +443,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testMultiDequeueManyWithPartiallyKnownShapesAndVariableSizeInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.string, dtypes_lib.int32),
shapes=((None,), (1, None)))
@@ -484,7 +484,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testMultiDequeueUpToPartiallyKnownShapesAndVariableInputNoBlocking(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(
10, (dtypes_lib.string, dtypes_lib.int32),
shapes=((None,), (1, None)))
@@ -525,7 +525,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_single_t[1].get_shape()))
def testHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, ((4, 4, 4, 4),))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
enqueue_op = q.enqueue_many((elems,))
@@ -535,7 +535,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(dequeued_t.eval(), elems)
def testPartiallyKnownHighDimension(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, (
(4, None, 4, None),))
elems = np.array([[[[[x] * 4] * 4] * 4] * 4 for x in range(10)], np.int32)
@@ -592,7 +592,7 @@ class PaddingFIFOQueueTest(test.TestCase):
array_ops.placeholder(dtypes_lib.int32)))
def testEnqueueWrongPartiallyKnownShapeAtRuntime(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First dimension of second component is unknown, second
# dimension must be 3.
q = data_flow_ops.PaddingFIFOQueue(10,
@@ -607,7 +607,7 @@ class PaddingFIFOQueueTest(test.TestCase):
feed_dict={elems_bad: np.array([1] * 12).reshape((3, 4))})
def testEnqueueDequeueManyWrongPartiallyKnownShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# First dimension of second component is unknown, second
# dimension must be 3.
q = data_flow_ops.PaddingFIFOQueue(10,
@@ -625,7 +625,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testParallelEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(100)]
enqueue_op = q.enqueue_many((elems,))
@@ -644,7 +644,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(dequeued_t.eval(), elems * 10)
def testParallelDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -666,7 +666,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(1000, dtypes_lib.float32, shapes=((),))
elems = [10.0 * x for x in range(1000)]
enqueue_op = q.enqueue_many((elems,))
@@ -690,7 +690,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertItemsEqual(elems, dequeued_elems)
def testParallelEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(50, dtypes_lib.float32, shapes=((),))
initial_elements = [10.0] * 49
q.enqueue_many((initial_elements,)).run()
@@ -723,7 +723,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertTrue(elem in (10.0, 20.0))
def testMixtureOfEnqueueAndEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
enqueue_placeholder = array_ops.placeholder(dtypes_lib.int32, shape=())
enqueue_op = q.enqueue((enqueue_placeholder,))
@@ -759,7 +759,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testMixtureOfDequeueAndDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.int32, shapes=((),))
enqueue_op = q.enqueue_many((np.arange(250, dtype=np.int32),))
dequeued_t = q.dequeue()
@@ -793,7 +793,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -820,7 +820,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testBlockingDequeueUpTo(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -847,7 +847,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elems, dequeued_elems)
def testDequeueManyWithTensorParameter(self):
- with self.test_session():
+ with self.cached_session():
# Define a first queue that contains integer counts.
dequeue_counts = [random.randint(1, 10) for _ in range(100)]
count_q = data_flow_ops.PaddingFIFOQueue(100, dtypes_lib.int32, ((),))
@@ -872,7 +872,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(elems, dequeued_elems)
def testDequeueFromClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -890,7 +890,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeued_t.eval()
def testBlockingDequeueFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -916,7 +916,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testDequeueUpToFromClosedQueueReturnsRemainder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -938,7 +938,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue()
@@ -958,7 +958,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -983,7 +983,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueManyButNotAllFromClosedQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1008,7 +1008,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueManyLargerThanCapacityWithConcurrentDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1045,7 +1045,7 @@ class PaddingFIFOQueueTest(test.TestCase):
close_thread.join()
def testClosedBlockingDequeueManyRestoresPartialBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, (dtypes_lib.float32,
dtypes_lib.float32), ((), ()))
elems_a = [1.0, 2.0, 3.0]
@@ -1078,7 +1078,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingDequeueManyFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_many(4)
@@ -1098,7 +1098,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testBlockingDequeueUpToFromClosedEmptyQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
close_op = q.close()
dequeued_t = q.dequeue_up_to(4)
@@ -1118,7 +1118,7 @@ class PaddingFIFOQueueTest(test.TestCase):
dequeue_thread.join()
def testEnqueueToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
close_op = q.close()
@@ -1131,7 +1131,7 @@ class PaddingFIFOQueueTest(test.TestCase):
enqueue_op.run()
def testEnqueueManyToClosedQueue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1145,7 +1145,7 @@ class PaddingFIFOQueueTest(test.TestCase):
enqueue_op.run()
def testBlockingEnqueueToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1168,7 +1168,7 @@ class PaddingFIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueManyToFullQueue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1195,7 +1195,7 @@ class PaddingFIFOQueueTest(test.TestCase):
thread.join()
def testBlockingEnqueueBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0, 40.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1232,7 +1232,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(0, q.size().eval())
def testBlockingEnqueueManyBeforeClose(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(4, dtypes_lib.float32, ((),))
elems = [10.0, 20.0, 30.0]
enqueue_op = q.enqueue_many((elems,))
@@ -1265,7 +1265,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(elem, dequeued_t.eval())
def testDoesNotLoseValue(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PaddingFIFOQueue(1, dtypes_lib.float32, ((),))
enqueue_op = q.enqueue((10.0,))
size_t = q.size()
@@ -1275,7 +1275,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(size_t.eval(), [1])
def testSharedQueueSameSession(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.PaddingFIFOQueue(
1, dtypes_lib.float32, ((),), shared_name="shared_queue")
q1.enqueue((10.0,)).run()
@@ -1305,7 +1305,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(q2_size_t.eval(), [0])
def testIncompatibleSharedQueueErrors(self):
- with self.test_session():
+ with self.cached_session():
q_a_1 = data_flow_ops.PaddingFIFOQueue(
10, dtypes_lib.float32, ((),), shared_name="q_a")
q_a_2 = data_flow_ops.PaddingFIFOQueue(
@@ -1356,7 +1356,7 @@ class PaddingFIFOQueueTest(test.TestCase):
q_f_2.queue_ref.op.run()
def testSelectQueue(self):
- with self.test_session():
+ with self.cached_session():
num_queues = 10
qlist = list()
for _ in xrange(num_queues):
@@ -1370,7 +1370,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertEqual(q.dequeue().eval(), 10.0)
def testSelectQueueOutOfRange(self):
- with self.test_session():
+ with self.cached_session():
q1 = data_flow_ops.PaddingFIFOQueue(10, dtypes_lib.float32, ((),))
q2 = data_flow_ops.PaddingFIFOQueue(15, dtypes_lib.float32, ((),))
enq_q = data_flow_ops.PaddingFIFOQueue.from_list(3, [q1, q2])
@@ -1394,7 +1394,7 @@ class PaddingFIFOQueueTest(test.TestCase):
sess.run(enqueue_many_op)
def testResetOfBlockingOperation(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_empty = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.float32, ((),))
dequeue_op = q_empty.dequeue()
dequeue_many_op = q_empty.dequeue_many(1)
@@ -1422,7 +1422,7 @@ class PaddingFIFOQueueTest(test.TestCase):
t.join()
def testBigEnqueueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(5, dtypes_lib.int32, ((),))
elem = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
enq = q.enqueue_many((elem,))
@@ -1467,7 +1467,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testBigDequeueMany(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PaddingFIFOQueue(2, dtypes_lib.int32, ((),))
elem = np.arange(4, dtype=np.int32)
enq_list = [q.enqueue((e,)) for e in elem]
@@ -1493,7 +1493,7 @@ class PaddingFIFOQueueTest(test.TestCase):
self.assertAllEqual(elem, results)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [
dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.int8, dtypes_lib.int64,
diff --git a/tensorflow/python/kernel_tests/parse_single_example_op_test.py b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
index bf4c89b368..a84895a287 100644
--- a/tensorflow/python/kernel_tests/parse_single_example_op_test.py
+++ b/tensorflow/python/kernel_tests/parse_single_example_op_test.py
@@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -844,7 +844,7 @@ class ParseExampleTest(test.TestCase):
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 7dff4501cc..71d8b60d3c 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -89,7 +89,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
class ParseExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -937,7 +937,7 @@ class ParseExampleTest(test.TestCase):
class ParseSingleExampleTest(test.TestCase):
def _test(self, kwargs, expected_values=None, expected_err=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1054,7 +1054,7 @@ class ParseSequenceExampleTest(test.TestCase):
expected_feat_list_values = expected_feat_list_values or {}
expected_length_values = expected_length_values or {}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if expected_err:
with self.assertRaisesWithPredicateMatch(expected_err[0],
expected_err[1]):
@@ -1606,7 +1606,7 @@ class ParseSequenceExampleTest(test.TestCase):
class DecodeJSONExampleTest(test.TestCase):
def _testRoundTrip(self, examples):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
examples = np.array(examples, dtype=np.object)
json_tensor = constant_op.constant(
@@ -1696,7 +1696,7 @@ class DecodeJSONExampleTest(test.TestCase):
])
def testInvalidSyntax(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
json_tensor = constant_op.constant(["{]"])
binary_tensor = parsing_ops.decode_json_example(json_tensor)
with self.assertRaisesOpError("Error while parsing JSON"):
@@ -1706,7 +1706,7 @@ class DecodeJSONExampleTest(test.TestCase):
class ParseTensorOpTest(test.TestCase):
def testToFloat32(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.float32)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1719,7 +1719,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testToUint8(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1732,7 +1732,7 @@ class ParseTensorOpTest(test.TestCase):
self.assertAllEqual(expected, result)
def testTypeMismatch(self):
- with self.test_session():
+ with self.cached_session():
expected = np.random.rand(3, 4, 5).astype(np.uint8)
tensor_proto = tensor_util.make_tensor_proto(expected)
@@ -1745,7 +1745,7 @@ class ParseTensorOpTest(test.TestCase):
tensor.eval(feed_dict={serialized: tensor_proto.SerializeToString()})
def testInvalidInput(self):
- with self.test_session():
+ with self.cached_session():
serialized = array_ops.placeholder(dtypes.string)
tensor = parsing_ops.parse_tensor(serialized, dtypes.uint16)
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 15d5702252..b34d30f5c0 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -39,7 +39,7 @@ from tensorflow.python.training import saver as saver_lib
class PartitionerCreatorsTest(test.TestCase):
def testFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable(
@@ -50,7 +50,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, (5, 1))
def testFixedSizePartitionerInt64(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(4, axis=0)
with variable_scope.variable_scope("root", partitioner=partitioner):
v0 = variable_scope.get_variable("v0", dtype=dtypes.int64, shape=[20])
@@ -58,7 +58,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertEqual(len(v0_list), 4)
def testResourceFixedSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
with variable_scope.variable_scope(
"root", partitioner=partitioner, use_resource=True):
@@ -88,7 +88,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testVariableAxisSizePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Create a partitioned variable of shape (4, 8, 16, 32) type float32
# Bytes per slice along the given axes:
@@ -210,7 +210,7 @@ class PartitionerCreatorsTest(test.TestCase):
self.assertAllEqual(v0_part, expected_partitions)
def testMinMaxVariablePartitioner(self):
- with self.test_session():
+ with self.cached_session():
# Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
self._testMinMaxVariablePartitioner(
max_partitions=100,
@@ -323,7 +323,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEquals(expected_specs[i], slices[i]._save_slice_info.spec)
def testVecConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([1, 2, 3, 4])
vs = partitioned_variables.create_partitioned_variables([4], [4], rnd_par)
variables.global_variables_initializer().run()
@@ -334,7 +334,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["4 0,1", "4 1,1", "4 2,1", "4 3,1"])
def testConstantInit(self):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
vs = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
rnd_par)
@@ -346,7 +346,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["2 4 0,2:0,2", "2 4 0,2:2,2"])
def _testNameHelper(self, use_resource=False):
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope("hi", use_resource=use_resource):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -363,7 +363,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test same variable.
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with variable_scope.variable_scope(
"hola", use_resource=use_resource) as vs:
@@ -383,7 +383,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual(var2_name + "/part_0:0", vs2[0].name)
self.assertEqual(var2_name + "/part_1:0", vs2[1].name)
# Test name_scope
- with self.test_session():
+ with self.cached_session():
rnd_par = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
with ops.name_scope("ola"):
vs1 = partitioned_variables.create_partitioned_variables([2, 4], [1, 2],
@@ -408,7 +408,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._testNameHelper(use_resource=True)
def testRandomInitValue(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([200, 40]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 10], rnd.initialized_value())
@@ -425,7 +425,7 @@ class PartitionedVariablesTestCase(test.TestCase):
])
def testRandomInitUnevenPartitions(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(
random_ops.random_uniform([20, 43], dtype=dtypes.float64))
var_lists = [
@@ -463,7 +463,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, save_specs[i])
def testDegenerate(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [1, 1], rnd.initialized_value())
@@ -474,7 +474,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self._TestSaveSpec(vs, ["10 43 0,10:0,43"])
def testSliceSizeOne(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
vs = partitioned_variables.create_partitioned_variables(
rnd.get_shape(), [10, 1], rnd.initialized_value())
@@ -492,7 +492,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose([0., 1., 2., 3.], _IotaInitializer([4]))
self.assertAllClose([[0., 1.], [0., 10.], [0., 100.], [0., 1000.]],
_IotaInitializer([4, 2]))
- with self.test_session():
+ with self.cached_session():
vs = partitioned_variables.create_partitioned_variables([13, 5], [3, 1],
_IotaInitializer)
variables.global_variables_initializer().run()
@@ -506,7 +506,7 @@ class PartitionedVariablesTestCase(test.TestCase):
def testRandomInitializer(self):
# Sanity check that the slices uses a different seed when using a random
# initializer function.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer())
variables.global_variables_initializer().run()
@@ -514,7 +514,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertTrue(np.linalg.norm(val0 - val1) > 1e-6)
# Negative test that proves that slices have the same values if
# the random initializer uses a seed.
- with self.test_session():
+ with self.cached_session():
var0, var1 = partitioned_variables.create_partitioned_variables(
[20, 12], [1, 2], init_ops.random_uniform_initializer(seed=201))
variables.global_variables_initializer().run()
@@ -522,7 +522,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertAllClose(val0, val1)
def testSomeErrors(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([10, 43]))
with self.assertRaises(ValueError):
partitioned_variables.create_partitioned_variables(
@@ -547,7 +547,7 @@ class PartitionedVariablesTestCase(test.TestCase):
[10, 43], [1, 50], rnd.initialized_value())
def testControlDepsNone(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dependency.
@@ -573,7 +573,7 @@ class PartitionedVariablesTestCase(test.TestCase):
self.assertEqual([], op.control_inputs)
def testConcat(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
var_x = variable_scope.get_variable(
"x",
initializer=constant_op.constant([1., 2.]),
diff --git a/tensorflow/python/kernel_tests/priority_queue_test.py b/tensorflow/python/kernel_tests/priority_queue_test.py
index 3fb9c9c468..73a9c81638 100644
--- a/tensorflow/python/kernel_tests/priority_queue_test.py
+++ b/tensorflow/python/kernel_tests/priority_queue_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
class PriorityQueueTest(test.TestCase):
def testRoundTripInsertReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -67,7 +67,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripInsertMultiThreadedReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -113,7 +113,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripFillsCapacityMultiThreadedEnqueueAndDequeue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(10, (dtypes.int64), (()))
num_threads = 40
@@ -163,7 +163,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(sorted(dequeued), sorted(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadManyMultithreadedSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
num_threads = 40
@@ -219,7 +219,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(set(dequeued), set(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-5, 5, size=100).astype(np.int64)
@@ -268,7 +268,7 @@ class PriorityQueueTest(test.TestCase):
self.assertEqual(missed, set())
def testRoundTripInsertOnceReadOnceSorts(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.PriorityQueue(2000, (dtypes.string, dtypes.string), (
(), ()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
@@ -289,7 +289,7 @@ class PriorityQueueTest(test.TestCase):
self.assertTrue((dv0, dv1) in allowed[e])
def testRoundTripInsertOnceReadManySorts(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
q.enqueue_many((elem, elem)).run()
@@ -297,7 +297,7 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(deq_values, sorted(elem))
def testRoundTripInsertOnceReadOnceLotsSorts(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.int64), (()))
elem = np.random.randint(-100, 100, size=1000).astype(np.int64)
q.enqueue_many((elem, elem)).run()
@@ -306,13 +306,13 @@ class PriorityQueueTest(test.TestCase):
self.assertAllEqual(deq_values, sorted(elem))
def testInsertingNonInt64Fails(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.PriorityQueue(2000, (dtypes.string), (()))
with self.assertRaises(TypeError):
q.enqueue_many((["a", "b", "c"], ["a", "b", "c"])).run()
def testInsertingNonScalarFails(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
input_priority = array_ops.placeholder(dtypes.int64)
input_other = array_ops.placeholder(dtypes.string)
q = data_flow_ops.PriorityQueue(2000, (dtypes.string,), (()))
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index 0ef6a95cfc..d199a9d9dd 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -320,6 +320,15 @@ class RandomUniformTest(RandomOpTestCommon):
error = np.abs(counts - mean)
self.assertLess(error.max(), 5 * std)
+ # Check that minval = maxval is fine iff we're producing no numbers
+ def testUniformIntsDegenerate(self):
+ for dt in dtypes.int32, dtypes.int64:
+ def sample(n):
+ return self._Sampler(n, minv=0, maxv=0, dtype=dt, use_gpu=True)()
+ self.assertEqual(sample(0).shape, (10, 0))
+ with self.assertRaisesOpError('Need minval < maxval, got 0 >= 0'):
+ sample(1)
+
# Checks that the CPU and GPU implementation returns the same results,
# given the same random seed
def testCPUGPUMatch(self):
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 8e06e1abfb..8c84b2a49f 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -146,7 +146,7 @@ class IdentityReaderTest(test.TestCase):
self.assertAllEqual(expected, v)
def testOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
work_completed = reader.num_work_units_completed()
produced = reader.num_records_produced()
@@ -180,7 +180,7 @@ class IdentityReaderTest(test.TestCase):
self.assertAllEqual(0, queued_length.eval())
def testMultipleEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
enqueue = queue.enqueue_many([["DD", "EE"]])
@@ -201,7 +201,7 @@ class IdentityReaderTest(test.TestCase):
sess.run([key, value])
def testSerializeRestore(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
produced = reader.num_records_produced()
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
@@ -256,7 +256,7 @@ class IdentityReaderTest(test.TestCase):
reader.restore_state(b"BOGUS" + state[5:]).run()
def testReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.IdentityReader("test_reader")
work_completed = reader.num_work_units_completed()
produced = reader.num_records_produced()
@@ -307,7 +307,7 @@ class WholeFileReaderTest(test.TestCase):
self.assertAllEqual(self._content[index], v)
def testOneEpoch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([self._filenames]).run()
@@ -323,7 +323,7 @@ class WholeFileReaderTest(test.TestCase):
sess.run([key, value])
def testInfiniteEpochs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
enqueue = queue.enqueue_many([self._filenames])
@@ -366,7 +366,7 @@ class TextLineReaderTest(test.TestCase):
return filenames
def _testOneEpoch(self, files):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TextLineReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -391,7 +391,7 @@ class TextLineReaderTest(test.TestCase):
def testSkipHeaderLines(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -522,7 +522,7 @@ class FixedLengthRecordReaderTest(TFCompressionTestCase):
# gap_bytes=hop_bytes-record_bytes
def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
record_bytes=self._record_bytes,
@@ -549,7 +549,7 @@ class FixedLengthRecordReaderTest(TFCompressionTestCase):
files,
num_overlapped_records,
encoding=None):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.FixedLengthRecordReader(
header_bytes=self._header_bytes,
record_bytes=self._record_bytes,
@@ -621,7 +621,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
def testOneEpoch(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -640,7 +640,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
def testReadUpTo(self):
files = self._CreateFiles()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
batch_size = 3
@@ -670,7 +670,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
files = self._CreateFiles(options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -687,7 +687,7 @@ class TFRecordReaderTest(TFCompressionTestCase):
options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
files = self._CreateFiles(options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.TFRecordReader(name="test_reader", options=options)
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -752,7 +752,7 @@ class LMDBReaderTest(test.TestCase):
shutil.copy(path, self.db_path)
def testReadFromFile(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_file")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -770,7 +770,7 @@ class LMDBReaderTest(test.TestCase):
k, v = sess.run([key, value])
def testReadFromSameFile(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
filename_queue = input_lib.string_input_producer(
@@ -789,7 +789,7 @@ class LMDBReaderTest(test.TestCase):
coord.join(threads)
def testReadFromFolder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_folder")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
key, value = reader.read(queue)
@@ -807,7 +807,7 @@ class LMDBReaderTest(test.TestCase):
k, v = sess.run([key, value])
def testReadFromFileRepeatedly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
filename_queue = input_lib.string_input_producer(
[self.db_path], num_epochs=None)
diff --git a/tensorflow/python/kernel_tests/record_input_test.py b/tensorflow/python/kernel_tests/record_input_test.py
index 068860d5d4..ebb9872f22 100644
--- a/tensorflow/python/kernel_tests/record_input_test.py
+++ b/tensorflow/python/kernel_tests/record_input_test.py
@@ -44,7 +44,7 @@ class RecordInputOpTest(test.TestCase):
w.close()
def testRecordInputSimple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", 1, 1)
yield_op = data_flow_ops.RecordInput(
@@ -57,7 +57,7 @@ class RecordInputOpTest(test.TestCase):
self.assertEqual(sess.run(yield_op), b"0000000000")
def testRecordInputSimpleGzip(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData(
"basic",
1,
@@ -76,7 +76,7 @@ class RecordInputOpTest(test.TestCase):
self.assertEqual(sess.run(yield_op), b"0000000000")
def testRecordInputSimpleZlib(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData(
"basic",
1,
@@ -98,7 +98,7 @@ class RecordInputOpTest(test.TestCase):
files = 100
records_per_file = 100
batches = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", files, records_per_file)
records = data_flow_ops.RecordInput(
@@ -126,7 +126,7 @@ class RecordInputOpTest(test.TestCase):
def testDoesNotDeadlock(self):
# Iterate multiple times to cause deadlock if there is a chance it can occur
for _ in range(30):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", 1, 1)
records = data_flow_ops.RecordInput(
@@ -141,7 +141,7 @@ class RecordInputOpTest(test.TestCase):
sess.run(yield_op)
def testEmptyGlob(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
record_input = data_flow_ops.RecordInput(file_pattern="foo")
yield_op = record_input.get_yield_op()
sess.run(variables.global_variables_initializer())
@@ -152,7 +152,7 @@ class RecordInputOpTest(test.TestCase):
files = 10
records_per_file = 10
batches = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.generateTestData("basic", files, records_per_file)
records = data_flow_ops.RecordInput(
diff --git a/tensorflow/python/kernel_tests/reduce_join_op_test.py b/tensorflow/python/kernel_tests/reduce_join_op_test.py
index 663561ced7..3bb4986313 100644
--- a/tensorflow/python/kernel_tests/reduce_join_op_test.py
+++ b/tensorflow/python/kernel_tests/reduce_join_op_test.py
@@ -113,7 +113,7 @@ class ReduceJoinTest(UnicodeTestCase):
keep_dims: Whether or not to retain reduced dimensions.
separator: The separator to use for joining.
"""
- with self.test_session():
+ with self.cached_session():
output = string_ops.reduce_join(
inputs=input_array,
axis=axis,
@@ -136,7 +136,7 @@ class ReduceJoinTest(UnicodeTestCase):
axis: The indices to reduce.
separator: The separator to use when joining.
"""
- with self.test_session():
+ with self.cached_session():
output = string_ops.reduce_join(
inputs=input_array, axis=axis, keep_dims=False, separator=separator)
output_keep_dims = string_ops.reduce_join(
@@ -234,7 +234,7 @@ class ReduceJoinTest(UnicodeTestCase):
input_array = [["a"], ["b"]]
truth = ["ab"]
truth_shape = None
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
reduced = string_ops.reduce_join(placeholder, axis=0)
output_array = reduced.eval(feed_dict={placeholder.name: input_array})
@@ -247,7 +247,7 @@ class ReduceJoinTest(UnicodeTestCase):
truth_dim_zero = ["thisplease", "isdo", "anot", "testpanic"]
truth_dim_one = ["thisisatest", "pleasedonotpanic"]
truth_shape = None
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
reduced = string_ops.reduce_join(input_array, axis=placeholder)
output_array_dim_zero = reduced.eval(feed_dict={placeholder.name: [0]})
@@ -298,7 +298,7 @@ class ReduceJoinTest(UnicodeTestCase):
self._testMultipleReduceJoin(input_array, axis=permutation)
def testInvalidReductionIndices(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"):
string_ops.reduce_join(inputs="", axis=0)
with self.assertRaisesRegexp(ValueError,
@@ -313,7 +313,7 @@ class ReduceJoinTest(UnicodeTestCase):
string_ops.reduce_join(inputs=[[""]], axis=[0, 2])
def testZeroDims(self):
- with self.test_session():
+ with self.cached_session():
inputs = np.zeros([0, 1], dtype=str)
# Reduction that drops the dim of size 0.
@@ -326,7 +326,7 @@ class ReduceJoinTest(UnicodeTestCase):
self.assertAllEqual([0], output_shape)
def testInvalidArgsUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
index_too_high = string_ops.reduce_join(placeholder, axis=1)
duplicate_index = string_ops.reduce_join(placeholder, axis=[-1, 1])
@@ -336,7 +336,7 @@ class ReduceJoinTest(UnicodeTestCase):
duplicate_index.eval(feed_dict={placeholder.name: [[""]]})
def testInvalidArgsUnknownIndices(self):
- with self.test_session():
+ with self.cached_session():
placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
reduced = string_ops.reduce_join(["test", "test2"], axis=placeholder)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index ea78b58d88..248036a82a 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -61,7 +61,7 @@ class ReducedShapeTest(test.TestCase):
self.assertAllEqual(output.eval(), result)
def testSimple(self):
- with self.test_session():
+ with self.cached_session():
self._check([3], [], [3])
self._check([3], [0], [1])
self._check([5, 3], [], [5, 3])
@@ -71,7 +71,7 @@ class ReducedShapeTest(test.TestCase):
def testZeros(self):
"""Check that reduced_shape does the right thing with zero dimensions."""
- with self.test_session():
+ with self.cached_session():
self._check([0], [], [0])
self._check([0], [0], [1])
self._check([0, 3], [], [0, 3])
@@ -84,7 +84,7 @@ class ReducedShapeTest(test.TestCase):
self._check([3, 0], [0, 1], [1, 1])
def testNegAxes(self):
- with self.test_session():
+ with self.cached_session():
self._check([10, 10, 10], [-1], [10, 10, 1])
self._check([10, 10, 10], [-1, 2], [10, 10, 1])
self._check([10, 10, 10], [-1, -1], [10, 10, 1])
@@ -95,7 +95,7 @@ class ReducedShapeTest(test.TestCase):
class ReductionUnknownShape(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
for dtype, reductions in [(dtypes.float32,
(math_ops.reduce_sum, math_ops.reduce_mean,
math_ops.reduce_prod, math_ops.reduce_max,
@@ -212,7 +212,7 @@ class SumReductionTest(BaseReductionTest):
arr = np.ones([68000], dtype=np.float16)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_arr = variables.Variable(arr)
variables.global_variables_initializer().run()
tf_mean = math_ops.reduce_mean(tf_arr, 0, False)
@@ -235,7 +235,7 @@ class SumReductionTest(BaseReductionTest):
col_sum = np.sum(arr, axis=0)
row_sum = np.sum(arr, axis=1)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce(arr, 1, False)
tf_col_sum = self._tf_reduce(arr, 0, False)
tf_out_row, tf_out_col = sess.run([tf_row_sum, tf_col_sum])
@@ -249,7 +249,7 @@ class SumReductionTest(BaseReductionTest):
sum_y = np.sum(arr, axis=1)
sum_xz = np.sum(arr, axis=(0, 2))
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce(arr, [0, 2], False)
tf_sum_y = self._tf_reduce(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
@@ -617,7 +617,7 @@ class MinReductionTest(test.TestCase):
def testGradient(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [1, 2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -627,7 +627,7 @@ class MinReductionTest(test.TestCase):
def testGradient2(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [1])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -637,7 +637,7 @@ class MinReductionTest(test.TestCase):
def testGradient3(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t, [2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -647,7 +647,7 @@ class MinReductionTest(test.TestCase):
def testGradient4(self):
s = [2, 3, 4, 2]
x = np.arange(1.0, 49.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_min(t)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -655,7 +655,7 @@ class MinReductionTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testEmptyGradients(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_min(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -744,7 +744,7 @@ class MaxReductionTest(test.TestCase):
def testGradient(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [1, 2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -754,7 +754,7 @@ class MaxReductionTest(test.TestCase):
def testGradient2(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [1])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -764,7 +764,7 @@ class MaxReductionTest(test.TestCase):
def testGradient3(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t, [2])
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -774,7 +774,7 @@ class MaxReductionTest(test.TestCase):
def testGradient4(self):
s = [2, 3, 4, 2]
x = np.arange(-49.0, -1.0).reshape(s).astype(np.float64)
- with self.test_session():
+ with self.cached_session():
t = ops.convert_to_tensor(x)
su = math_ops.reduce_max(t)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -782,7 +782,7 @@ class MaxReductionTest(test.TestCase):
self.assertAllClose(jacob_t, jacob_n, rtol=1e-8, atol=1e-8)
def testEmptyGradients(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros([0, 3])
y = math_ops.reduce_max(x, [1])
error = gradient_checker.compute_gradient_error(x, [0, 3], y, [0])
@@ -960,7 +960,7 @@ class CountNonzeroReductionTest(test.TestCase):
def testStringReduce(self):
# Test case for GitHub issue 18712
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v = math_ops.count_nonzero(constant_op.constant(["test"]))
self.assertAllClose(sess.run(v), 1)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
index d70360775a..1e8524f72a 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
@@ -63,7 +63,7 @@ class BigReductionTest(BaseReductionTest):
row_sum = np.ones([size_x], dtype=np.float32) * size_y
full_sum = np.ones([], dtype=np.float32) * size_x * size_y
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce_sum(arr, 1, False)
tf_col_sum = self._tf_reduce_sum(arr, 0, False)
tf_full_sum = self._tf_reduce_sum(arr, [0, 1], False)
@@ -81,7 +81,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.ones([size_x, size_z], dtype=np.float32)
sum_xz = np.ones([size_y], dtype=np.float32)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_mean(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_mean(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
@@ -106,7 +106,7 @@ class BigReductionTest(BaseReductionTest):
row_max = np.max(arr, axis=1)
full_max = np.max(col_max)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_max = self._tf_reduce_max(arr, 1, False)
tf_col_max = self._tf_reduce_max(arr, 0, False)
tf_full_max = self._tf_reduce_max(arr, [0, 1], False)
@@ -125,7 +125,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.max(arr, axis=1)
sum_xz = np.max(arr, axis=(0, 2))
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_max(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_max(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
@@ -149,7 +149,7 @@ class BigReductionTest(BaseReductionTest):
row_sum = np.ones([size_x], dtype=np.bool)
full_sum = np.ones([1], dtype=np.bool).reshape([])
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce_all(arr, 1, False)
tf_col_sum = self._tf_reduce_all(arr, 0, False)
tf_full_sum = self._tf_reduce_all(arr, [0, 1], False)
@@ -167,7 +167,7 @@ class BigReductionTest(BaseReductionTest):
sum_y = np.ones([size_x, size_z], dtype=np.bool)
sum_xz = np.ones([size_y], dtype=np.bool)
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce_all(arr, [0, 2], False)
tf_sum_y = self._tf_reduce_all(arr, 1, False)
tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
diff --git a/tensorflow/python/kernel_tests/regex_full_match_op_test.py b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
index 7bd8c3ca27..98746e7d9b 100644
--- a/tensorflow/python/kernel_tests/regex_full_match_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_full_match_op_test.py
@@ -35,28 +35,28 @@ class RegexFullMatchOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRegexFullMatch(self, op):
values = ["abaaba", "abcdabcde"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([True, False], matched)
def testRegexFullMatchTwoDims(self, op):
values = [["abaaba", "abcdabcde"], ["acdcba", "ebcda"]]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "a.*a").eval()
self.assertAllEqual([[True, False], [True, False]], matched)
def testEmptyMatch(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
matched = op(input_tensor, "").eval()
self.assertAllEqual([False, False], matched)
def testInvalidPattern(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
matched = op(input_tensor, invalid_pattern)
@@ -68,7 +68,7 @@ class RegexFullMatchOpTest(test.TestCase):
def testRegexFullMatchDelegation(self):
with compat.forward_compatibility_horizon(2018, 11, 1):
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant("foo", dtypes.string)
pattern = "[a-z]"
op = string_ops.regex_full_match(input_tensor, pattern)
@@ -80,7 +80,7 @@ class RegexFullMatchOpTest(test.TestCase):
def testStaticRegexFullMatchDelegation(self):
with compat.forward_compatibility_horizon(2018, 11, 20):
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant("foo", dtypes.string)
pattern = "[a-z]*"
op = string_ops.regex_full_match(input_tensor, pattern)
diff --git a/tensorflow/python/kernel_tests/regex_replace_op_test.py b/tensorflow/python/kernel_tests/regex_replace_op_test.py
index f0e84b8fca..d9b7ed28d2 100644
--- a/tensorflow/python/kernel_tests/regex_replace_op_test.py
+++ b/tensorflow/python/kernel_tests/regex_replace_op_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_string_ops
@@ -34,7 +33,7 @@ from tensorflow.python.platform import test
class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testForwarding(self, op):
- with self.test_session():
+ with self.cached_session():
# Generate an input that is uniquely consumed by the regex op.
# This exercises code paths which are optimized for this case
# (e.g., using forwarding).
@@ -48,7 +47,7 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRemovePrefix(self, op):
values = ["a:foo", "a:bar", "a:foo", "b:baz", "b:qux", "ca:b"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "^(a:|b:)", "", replace_global=False).eval()
self.assertAllEqual([b"foo", b"bar", b"foo", b"baz", b"qux", b"ca:b"],
@@ -56,21 +55,21 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testRegexReplace(self, op):
values = ["aba\naba", "abcdabcde"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "a.*a", "(\\0)").eval()
self.assertAllEqual([b"(aba)\n(aba)", b"(abcda)bcde"], stripped)
def testEmptyMatch(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "", "x").eval()
self.assertAllEqual([b"xaxbxcx", b"x1x"], stripped)
def testInvalidPattern(self, op):
values = ["abc", "1"]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
invalid_pattern = "A["
replace = op(input_vector, invalid_pattern, "x")
@@ -79,7 +78,7 @@ class RegexReplaceOpVariantsTest(test.TestCase, parameterized.TestCase):
def testGlobal(self, op):
values = ["ababababab", "abcabcabc", ""]
- with self.test_session():
+ with self.cached_session():
input_vector = constant_op.constant(values, dtypes.string)
stripped = op(input_vector, "ab", "abc", True).eval()
self.assertAllEqual([b"abcabcabcabcabc", b"abccabccabcc", b""], stripped)
@@ -100,22 +99,20 @@ class RegexReplaceTest(test.TestCase, parameterized.TestCase):
(as_tensor, as_string),
(as_tensor, as_tensor))
def testRegexReplaceDelegation(self, pattern_fn, rewrite_fn):
- with compat.forward_compatibility_horizon(2018, 10, 11):
- with self.test_session():
- input_vector = constant_op.constant("foo", dtypes.string)
- pattern = pattern_fn("[a-z]")
- replace = rewrite_fn(".")
- op = string_ops.regex_replace(input_vector, pattern, replace)
- self.assertTrue(op.name.startswith("RegexReplace"))
+ with self.cached_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = pattern_fn("[a-z]")
+ replace = rewrite_fn(".")
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("RegexReplace"))
def testStaticRegexReplaceDelegation(self):
- with compat.forward_compatibility_horizon(2018, 10, 11):
- with self.test_session():
- input_vector = constant_op.constant("foo", dtypes.string)
- pattern = "[a-z]"
- replace = "."
- op = string_ops.regex_replace(input_vector, pattern, replace)
- self.assertTrue(op.name.startswith("StaticRegexReplace"))
+ with self.cached_session():
+ input_vector = constant_op.constant("foo", dtypes.string)
+ pattern = "[a-z]"
+ replace = "."
+ op = string_ops.regex_replace(input_vector, pattern, replace)
+ self.assertTrue(op.name.startswith("StaticRegexReplace"))
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index d97a1613b9..b26e944af8 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -105,7 +105,7 @@ class ReluTest(test.TestCase):
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -150,7 +150,7 @@ class ReluTest(test.TestCase):
self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -167,7 +167,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -184,7 +184,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -202,7 +202,7 @@ class ReluTest(test.TestCase):
self.assertLess(err, 1e-10)
def testGradientScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = variables.Variable(100.)
y = nn_ops.relu(x)
loss = y**2
@@ -250,7 +250,7 @@ class Relu6Test(test.TestCase):
# not well defined at around zero and six and we want to avoid that
# in terms of input values.
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
shape=[2, 5],
@@ -266,7 +266,7 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 6.1, 6.3, 6.5, 6.7, 6.9],
shape=[2, 5],
@@ -430,7 +430,7 @@ class EluTest(test.TestCase):
use_gpu=True)
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, name="x")
y = nn_ops.elu(x, name="elu")
@@ -441,7 +441,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
y = nn_ops.elu(x, name="elu")
@@ -452,7 +452,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-6)
def testGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtype=dtypes.float32)
elu = nn_ops.elu(x)
g, = gradients_impl.gradients(elu, x)
@@ -463,7 +463,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -480,7 +480,7 @@ class EluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -532,7 +532,7 @@ class SeluTest(test.TestCase):
use_gpu=True)
def testGradientFloat32(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, name="x")
y = nn_ops.selu(x, name="selu")
@@ -543,7 +543,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradientFloat64(self):
- with self.test_session():
+ with self.cached_session():
x_val = [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]]
x = constant_op.constant(x_val, dtype=dtypes.float64, name="x")
y = nn_ops.selu(x, name="selu")
@@ -554,7 +554,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-6)
def testGradGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -571,7 +571,7 @@ class SeluTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGradFloat64(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -620,7 +620,7 @@ class CreluTest(test.TestCase):
use_gpu=True)
def testNumbersWithAxis0(self):
- with self.test_session():
+ with self.cached_session():
crelu = nn_ops.crelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=0)
tf_relu = crelu.eval()
@@ -629,7 +629,7 @@ class CreluTest(test.TestCase):
self.assertAllEqual(np_crelu, tf_relu)
def testNumbersWithAxis1(self):
- with self.test_session():
+ with self.cached_session():
crelu = nn_ops.crelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]), axis=1)
tf_relu = crelu.eval()
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index ef9b439230..ca3ff1d1df 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -94,7 +94,7 @@ class ReshapeTest(test.TestCase):
def testFloatReshapeGradThreeDimensions(self):
x = np.arange(1., 25.).reshape([2, 3, 4]).astype(np.float32)
s = list(np.shape(x))
- with self.test_session():
+ with self.cached_session():
input_tensor = constant_op.constant(x)
reshape_out = array_ops.reshape(input_tensor, [1, 8, 3])
err = gradient_checker.compute_gradient_error(
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index 9beb615b2c..8fc71e0c57 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -120,7 +120,7 @@ class ReverseSequenceTest(test.TestCase):
batch_axis = 2
seq_lengths = np.asarray([3, 0, 4], dtype=np.int64)
- with self.test_session():
+ with self.cached_session():
input_t = constant_op.constant(x, shape=x.shape)
seq_lengths_t = constant_op.constant(seq_lengths, shape=seq_lengths.shape)
reverse_sequence_out = array_ops.reverse_sequence(
@@ -171,7 +171,7 @@ class ReverseSequenceTest(test.TestCase):
seq_axis=0,
batch_axis=3)
- with self.test_session():
+ with self.cached_session():
inputs = array_ops.placeholder(dtypes.float32, shape=(32, 2, 3))
seq_lengths = array_ops.placeholder(dtypes.int64, shape=(32,))
output = array_ops.reverse_sequence(
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index a28cdc3b26..05ad9f6336 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -516,7 +516,7 @@ class RNNTest(test.TestCase):
fix_weights_generator.build((None, input_shape))
weights = fix_weights_generator.get_weights()
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
inputs = array_ops.placeholder(
dtypes.float32, shape=(None, timestep, input_shape))
cell = keras.layers.SimpleRNNCell(output_shape)
@@ -524,7 +524,7 @@ class RNNTest(test.TestCase):
cell, inputs, dtype=dtypes.float32)
cell.set_weights(weights)
[tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
k_input = keras.Input(shape=(timestep, input_shape),
dtype=dtypes.float32)
cell = keras.layers.SimpleRNNCell(output_shape)
@@ -536,7 +536,7 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_state, k_state)
def testBasicLSTMCellInterchangeWithLSTMCell(self):
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
basic_cell(array_ops.ones([1, 1]),
state=basic_cell.get_initial_state(inputs=None,
@@ -548,7 +548,7 @@ class RNNTest(test.TestCase):
prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_path = save.save(sess, prefix)
- with self.test_session(graph=ops_lib.Graph()) as sess:
+ with self.session(graph=ops_lib.Graph()) as sess:
lstm_cell = rnn_cell_impl.LSTMCell(1, name="basic_lstm_cell")
lstm_cell(array_ops.ones([1, 1]),
state=lstm_cell.get_initial_state(inputs=None,
diff --git a/tensorflow/python/kernel_tests/scalar_test.py b/tensorflow/python/kernel_tests/scalar_test.py
index 287919bab7..d15f2c7b50 100644
--- a/tensorflow/python/kernel_tests/scalar_test.py
+++ b/tensorflow/python/kernel_tests/scalar_test.py
@@ -53,7 +53,7 @@ class ScalarTest(test.TestCase):
for version in strict + lenient:
with ops.Graph().as_default() as g:
test_util.set_producer_version(g, version)
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
feed = {}
xs = placeholders(args, feed)
x = op(*xs)
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f2f3023469..86e063cb36 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -294,7 +294,7 @@ class StatefulScatterNdTest(test.TestCase):
self.assertAllEqual(scatter_update.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
ref.initializer.run()
self.assertAllEqual(expected_result, scatter_update.eval())
@@ -409,7 +409,7 @@ class ScatterNdTest(test.TestCase):
expected = np.array([b"", b"one", b"", b"three", b"four",
b"", b"", b"seven"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -420,7 +420,7 @@ class ScatterNdTest(test.TestCase):
dtype=dtypes.string)
expected = np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"])
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertAllEqual(expected, result)
@@ -432,7 +432,7 @@ class ScatterNdTest(test.TestCase):
expected = [np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"])]
scatter = self.scatter_nd(indices, updates, shape=(8,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = sess.run(scatter)
self.assertTrue(np.array_equal(result, expected[0]) or
np.array_equal(result, expected[1]))
@@ -451,7 +451,7 @@ class ScatterNdTest(test.TestCase):
scatter = self.scatter_nd(indices, updates, shape)
self.assertAllEqual(scatter.get_shape().as_list(), shape)
expected_result = np.zeros([2, 2], dtype=np.int32)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_result, scatter.eval())
def testUndefinedIndicesShape(self):
@@ -486,7 +486,7 @@ class ScatterNdTest(test.TestCase):
updates = array_ops.placeholder(dtypes.int32, shape=None)
shape = constant_op.constant([0, 3, 2], dtypes.int32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(
"Indices and updates specified for empty output"):
self.scatter_nd(indices, updates, shape).eval(feed_dict={
@@ -500,7 +500,7 @@ class ScatterNdTest(test.TestCase):
shape = constant_op.constant([0], dtypes.int32)
scatter = self.scatter_nd(indices, updates, shape)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(scatter.eval().size, 0)
def testRank3InvalidShape1(self):
@@ -531,7 +531,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([1, 4], dtype=np.float64)
expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -548,7 +548,7 @@ class ScatterNdTest(test.TestCase):
[outputs], [updates, input_], [grad_vals])
expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -570,7 +570,7 @@ class ScatterNdTest(test.TestCase):
[[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
expected_input_grad = np.array(
[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -607,7 +607,7 @@ class ScatterNdTest(test.TestCase):
[[[[1, 2], [3, 4]]]],
[[[[5, 6], [7, 8]]]]
]]], dtype=np.float64)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_updates_grad, updates_grad.eval())
if self.non_aliasing_add_test:
self.assertAllEqual(expected_input_grad, input_grad.eval())
@@ -616,33 +616,33 @@ class ScatterNdTest(test.TestCase):
indices = array_ops.zeros([100000, 1], dtypes.int32)
values = np.random.randn(100000)
shape = [1]
- with self.test_session():
+ with self.cached_session():
val = self.scatter_nd(indices, values, shape).eval()
self.assertAllClose([np.sum(values)], val)
def testSmokeScatterNdBatch2DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
values = array_ops.zeros([3, 5, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim2(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
values = array_ops.zeros([0, 7])
shape = [4, 6, 7]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
self.scatter_nd(indices, values, shape).eval()
def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
- with self.test_session():
+ with self.cached_session():
indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
values = array_ops.zeros([1, 2, 6, 7, 8, 9])
shape = [3, 4, 5, 6, 7, 8, 9]
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index a82855dfeb..2931877c11 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -177,7 +177,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid1(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [-1, -1, 0, 0]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -188,7 +188,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid2(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [0, 1, 0, 1]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -197,7 +197,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
def testSegmentIdsInvalid3(self):
shape = [4, 4]
- with self.test_session():
+ with self.cached_session():
tf_x, _ = self._input(shape)
indices = [0, 1, 2, 0]
s = math_ops.segment_sum(data=tf_x, segment_ids=indices)
@@ -233,7 +233,7 @@ class SegmentReductionOpTest(SegmentReductionHelper):
math_ops.segment_sum, math_ops.segment_mean, math_ops.segment_min,
math_ops.segment_max
]:
- with self.test_session():
+ with self.cached_session():
tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
s = tf_op(data=tf_x, segment_ids=indices)
jacob_t, jacob_n = gradient_checker.compute_gradient(
@@ -300,7 +300,7 @@ class UnsortedSegmentTest(SegmentReductionHelper):
tf_ans = s.eval()
if dtype is dtypes_lib.bfloat16:
tf_ans = tf_ans.astype(np.float32)
- self.assertAllClose(np_ans, tf_ans)
+ self.assertAllCloseAccordingToType(np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
def testNumSegmentsTypes(self):
@@ -736,7 +736,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
segment_indices = [0, 1, 2, 2]
num_indices = len(segment_indices)
for tf_op in [math_ops.sparse_segment_sum, math_ops.sparse_segment_mean]:
- with self.test_session():
+ with self.cached_session():
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtypes_lib.float64)
s = tf_op(data=tf_x, indices=tf_indices, segment_ids=segment_indices)
@@ -758,7 +758,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
math_ops.sparse_segment_sum_with_num_segments,
math_ops.sparse_segment_mean_with_num_segments,
]:
- with self.test_session():
+ with self.cached_session():
tf_indices, _, tf_x, np_x = self._sparse_input(
shape, num_indices, dtype=dtypes_lib.float64)
s = tf_op(
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
index 678016b13d..03e1ae852f 100644
--- a/tensorflow/python/kernel_tests/session_ops_test.py
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -31,7 +31,7 @@ from tensorflow.python.platform import test
class SessionOpsTest(test.TestCase):
def testHandleBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -45,7 +45,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))
def testHandleEval(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -57,7 +57,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(50, h.eval())
def testHandleAndValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle and a value.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -70,7 +70,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(500, v)
def testHandleCond(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle and a value
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -90,7 +90,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(5000, result)
def testHandleForLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize a handle.
a = constant_op.constant(0)
h = session_ops.get_session_handle(a)
@@ -107,7 +107,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(100, h.eval())
def testHandleWhileLoop(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize a handle.
a = constant_op.constant(0)
h = session_ops.get_session_handle(a)
@@ -127,7 +127,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(101, h.eval())
def testHandleMover(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -148,7 +148,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))
def testHandleDelete(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -157,7 +157,7 @@ class SessionOpsTest(test.TestCase):
sess.run(h).delete()
def testHandleDeleteRaw(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Return a handle.
a = constant_op.constant(10)
b = constant_op.constant(5)
@@ -171,7 +171,7 @@ class SessionOpsTest(test.TestCase):
sess.run(x, feed_dict={f: raw_h})
def testMultiDevices(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device(test.gpu_device_name()):
a = constant_op.constant(1.0)
a_handle = sess.run(session_ops.get_session_handle(a))
@@ -189,7 +189,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(3.0, c_handle.eval())
def testHandleGC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# initial values live on CPU
with ops.device("/cpu:0"):
one = constant_op.constant(1, dtype=dtypes.float32)
@@ -213,7 +213,7 @@ class SessionOpsTest(test.TestCase):
add_h2: x_handle.handle})
def testHandlePlacement(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(1.0)
a_handle_op = session_ops.get_session_handle(a)
b = constant_op.constant(2.0)
@@ -233,7 +233,7 @@ class SessionOpsTest(test.TestCase):
self.assertEqual(3.0, c_handle.eval())
def testFeedOneHandleDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -244,7 +244,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
def testDirectHandleFeedOverlappingWithFetches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -270,7 +270,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(50.0, d_val)
def testFeedTwoHandlesDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
@@ -284,7 +284,7 @@ class SessionOpsTest(test.TestCase):
self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
def testFeedHandleToVariableDirectly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = variables.Variable(12.0)
inc_a = state_ops.assign_add(a, 2.0)
b = math_ops.add(a, 5.0)
diff --git a/tensorflow/python/kernel_tests/sets_test.py b/tensorflow/python/kernel_tests/sets_test.py
index 52b723802f..8335e9c139 100644
--- a/tensorflow/python/kernel_tests/sets_test.py
+++ b/tensorflow/python/kernel_tests/sets_test.py
@@ -158,7 +158,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
for op in ops:
self.assertEqual(None, op.get_shape().dims)
self.assertEqual(dtypes.int32, op.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
results = sess.run(ops)
self.assertAllEqual(results[0], results[1])
return results[0]
@@ -477,7 +477,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
dynamic_values_shape_ops = []
static_indices_shape = None
static_values_shape = None
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for op in ops:
if static_indices_shape is None:
static_indices_shape = op.indices.get_shape()
@@ -533,7 +533,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_intersection_count(self, a, b):
op = sets.set_size(sets.set_intersection(a, b))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def test_set_difference_multirow_2d(self):
@@ -971,7 +971,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_difference_count(self, a, b, aminusb=True):
op = sets.set_size(sets.set_difference(a, b, aminusb))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def test_set_union_multirow_2d(self):
@@ -1220,7 +1220,7 @@ class SetOpsTest(test_util.TensorFlowTestCase):
def _set_union_count(self, a, b):
op = sets.set_size(sets.set_union(a, b))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
return sess.run(op)
def _assert_set_operation(self, expected_indices, expected_values,
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 34e34d9d1b..0304dc3875 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -158,7 +158,7 @@ class ShapeOpsTest(test.TestCase):
# Disabled because it takes too long to run, but manually verified
# as passing at time of writing.
def _test64BitOutput(self):
- with self.test_session():
+ with self.cached_session():
inp = array_ops.zeros([2**31])
num_elements = array_ops.size_internal(
inp, optimize=False, out_type=dtypes.int64)
@@ -166,7 +166,7 @@ class ShapeOpsTest(test.TestCase):
# Too large for tf.int32 output.
with self.assertRaises(errors_impl.InvalidArgumentError):
- with self.test_session():
+ with self.cached_session():
inp = array_ops.zeros([2**31])
num_elements = array_ops.size_internal(
inp, optimize=False, out_type=dtypes.int32)
@@ -228,7 +228,7 @@ class ShapeOpsTest(test.TestCase):
self._compareExpandDimsAll(choice([2, 3, 5]), -4)
def testExpandDimsErrors(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaises(ValueError, array_ops.expand_dims,
np.zeros([2, 3, 5]), -5)
self.assertRaises(ValueError, array_ops.expand_dims,
@@ -239,7 +239,7 @@ class ShapeOpsTest(test.TestCase):
[False, True, True], 4)
def testExpandDimsGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(
np.random.rand(4, 2).astype("f"), dtype=dtypes.float32)
squeezed = array_ops.expand_dims(inp, 1)
@@ -249,7 +249,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testExpandDimsScalar(self):
- with self.test_session():
+ with self.cached_session():
inp = constant_op.constant(7)
self.assertAllEqual([7], array_ops.expand_dims(inp, 0).eval())
self.assertAllEqual([7], array_ops.expand_dims(inp, -1).eval())
@@ -375,7 +375,7 @@ class ShapeOpsTest(test.TestCase):
np.zeros([1, 2, 1]), [2, 3])
def testSqueezeGradient(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = array_ops.reshape(inp, [4, 1, 2])
squeezed = array_ops.squeeze(a, [])
@@ -385,7 +385,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testSqueezeGradientWithSqueezeDims(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = array_ops.reshape(inp, [4, 1, 2, 1])
squeezed = array_ops.squeeze(a, [1])
@@ -395,7 +395,7 @@ class ShapeOpsTest(test.TestCase):
self.assertLess(err, 1e-3)
def testSqueezeWithUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtypes.float32, shape=[2, None])
squeezed = array_ops.squeeze(a, [1])
@@ -433,7 +433,7 @@ class TileTest(test.TestCase):
self.assertTrue((result == np.tile(inp, (1, 4))).all())
def testIdentityTileAndGrad(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype(np.float32)
a = constant_op.constant(inp)
tiled = array_ops.tile(a, [1, 1])
@@ -443,7 +443,7 @@ class TileTest(test.TestCase):
self.assertTrue((result == np.tile(inp, (1, 1))).all())
def testEmpty(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(2, 3).astype(np.float32)
a = constant_op.constant(inp)
tiled = array_ops.tile(a, [5, 0])
@@ -453,7 +453,7 @@ class TileTest(test.TestCase):
def testUnknownInputShape(self):
"""Importing can call _TileShape without shape of <multiples> known."""
- with self.test_session():
+ with self.cached_session():
inp = array_ops.placeholder(dtypes.float32) # unknown shape
multiples = constant_op.constant([1, 2, 3, 4], dtype=np.int32)
tiled = array_ops.tile(inp, multiples)
@@ -503,7 +503,7 @@ class TileTest(test.TestCase):
self.assertAllEqual(result, np.tile(inp, (1, 4)))
def testInvalidDim(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype("f")
a = constant_op.constant(
[float(x) for x in inp.ravel(order="C")],
@@ -546,7 +546,7 @@ class TileTest(test.TestCase):
self._RunAndVerifyResult(10, use_gpu=True)
def testGradientSimpleReduction(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 1).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 1], dtype=dtypes.float32)
@@ -561,7 +561,7 @@ class TileTest(test.TestCase):
self.assertAllClose(np.sum(grad_inp, axis=1).reshape(4, 1), result, 1e-3)
def testGradientStridedReduction(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -634,7 +634,7 @@ class TileTest(test.TestCase):
self._RunAndVerifyGradientResult([2, 1, 3, 3, 2], [1, 3, 3, 1, 2])
def testGradientStridedReductionGC(self):
- with self.test_session():
+ with self.cached_session():
inp = np.random.rand(4, 2).astype("f")
a = constant_op.constant(
[float(x) for x in inp.flatten()], shape=[4, 2], dtype=dtypes.float32)
@@ -647,7 +647,7 @@ class TileTest(test.TestCase):
dtype=dtypes.float32)
outputs = array_ops.gather(array_ops.tile(inputs, [3]),
[1, 5, 9, 3, 7, 2, 2, 2])
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
@@ -659,7 +659,7 @@ class TileTest(test.TestCase):
inputs = array_ops.reshape(inputs, [-1, 1, 1])
outputs = array_ops.gather(array_ops.tile(inputs, [3, 4, 2]),
[1, 5, 9, 3, 7, 2, 2, 2])
- with self.test_session():
+ with self.cached_session():
error = gradient_checker.compute_gradient_error(
inputs, inputs.get_shape().as_list(),
outputs, outputs.get_shape().as_list())
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 40d384c623..c08d3222b3 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -107,7 +107,7 @@ class SliceTest(test.TestCase):
def testScalarInput(self):
input_val = 0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
constant_op.constant(input_val)[:].get_shape()
@@ -121,7 +121,7 @@ class SliceTest(test.TestCase):
def testInvalidIndex(self):
input_val = [1, 2]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
constant_op.constant(input_val)[1:, 1:].get_shape()
diff --git a/tensorflow/python/kernel_tests/softmax_op_test.py b/tensorflow/python/kernel_tests/softmax_op_test.py
index fbf1adba9b..89f4697e5c 100644
--- a/tensorflow/python/kernel_tests/softmax_op_test.py
+++ b/tensorflow/python/kernel_tests/softmax_op_test.py
@@ -22,7 +22,6 @@ import unittest
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import array_ops
@@ -163,10 +162,9 @@ class SoftmaxTest(test.TestCase):
self._testOverflow(use_gpu=False)
def test1DTensorAsInputNoReshape(self):
- with compat.forward_compatibility_horizon(2018, 8, 27):
- self._testSoftmax(
- np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
- self._testOverflow(use_gpu=False)
+ self._testSoftmax(
+ np.array([3., 2., 3., 9.]).astype(np.float64), use_gpu=False)
+ self._testOverflow(use_gpu=False)
def test3DTensorAsInput(self):
self._testSoftmax(
@@ -177,13 +175,12 @@ class SoftmaxTest(test.TestCase):
self._testOverflow(use_gpu=False)
def test3DTensorAsInputNoReshape(self):
- with compat.forward_compatibility_horizon(2018, 8, 27):
- self._testSoftmax(
- np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
- [[2., 3., 4., 5.], [6., 7., 8., 9.]],
- [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
- use_gpu=False)
- self._testOverflow(use_gpu=False)
+ self._testSoftmax(
+ np.array([[[1., 1., 1., 1.], [1., 2., 3., 4.]],
+ [[2., 3., 4., 5.], [6., 7., 8., 9.]],
+ [[5., 4., 3., 2.], [1., 2., 3., 4.]]]).astype(np.float32),
+ use_gpu=False)
+ self._testOverflow(use_gpu=False)
def testAlongFirstDimension(self):
self._testSoftmax(
@@ -210,7 +207,7 @@ class SoftmaxTest(test.TestCase):
self.assertEqual([3, 2, 4], op.get_shape())
def testEmptyInput(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=[0, 3])
self.assertEqual(0, array_ops.size(x).eval())
# reshape would raise if logits is empty
@@ -218,7 +215,7 @@ class SoftmaxTest(test.TestCase):
nn_ops.softmax(x, axis=0).eval()
def testDimTooLarge(self):
- with self.test_session():
+ with self.cached_session():
# Use placeholder to make sure we get runtime error instead of shape
# inference error.
dim = array_ops.placeholder_with_default(100, shape=[])
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index c0269db9ae..e8dc272637 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
@@ -72,7 +71,7 @@ class SoftplusTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -88,7 +87,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 1e-4)
def testGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -105,7 +104,7 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 5e-5)
def testGradGradGrad(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -123,10 +122,10 @@ class SoftplusTest(test.TestCase):
self.assertLess(err, 5e-5)
def testNoInts(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softplus'"):
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
nn_ops.softplus(constant_op.constant(7)).eval()
diff --git a/tensorflow/python/kernel_tests/softsign_op_test.py b/tensorflow/python/kernel_tests/softsign_op_test.py
index a5247ce08d..1b4db9fa46 100644
--- a/tensorflow/python/kernel_tests/softsign_op_test.py
+++ b/tensorflow/python/kernel_tests/softsign_op_test.py
@@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import errors
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
@@ -51,7 +50,7 @@ class SoftsignTest(test.TestCase):
use_gpu=True)
def testGradient(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant(
[-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
shape=[2, 5],
@@ -67,10 +66,10 @@ class SoftsignTest(test.TestCase):
self.assertLess(err, 1e-4)
def testNoInts(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "No OpKernel was registered to support Op 'Softsign'"):
+ TypeError,
+ "'features' has DataType int32 not in list of allowed values"):
nn_ops.softsign(constant_op.constant(7)).eval()
diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
index 2a9232b6ae..e267c05915 100644
--- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
@@ -551,7 +551,7 @@ class SpaceToBatchNDGradientTest(test.TestCase):
def _checkGrad(self, x, block_shape, paddings):
block_shape = np.array(block_shape)
paddings = np.array(paddings).reshape((len(block_shape), 2))
- with self.test_session():
+ with self.cached_session():
tf_x = ops.convert_to_tensor(x)
tf_y = array_ops.space_to_batch_nd(tf_x, block_shape, paddings)
epsilon = 1e-5
@@ -638,7 +638,7 @@ class RequiredSpaceToBatchPaddingsTest(test.TestCase):
t_paddings, t_crops = array_ops.required_space_to_batch_paddings(
input_shape_placeholder, block_shape_placeholder,
base_paddings_placeholder)
- with self.test_session():
+ with self.cached_session():
paddings_result = t_paddings.eval(assignments)
crops_result = t_crops.eval(assignments)
self.assertAllEqual(paddings_result, paddings_const)
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index 3bb5e899fe..a824d5c826 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -99,20 +99,20 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
""", q.accumulator_ref.op.node_def)
def testAccumulatorSizeEmpty(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q")
self.assertEqual(q.num_accumulated().eval(), 0)
def testAccumulatorSetGlobalStep(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
set_global_step_op = q.set_global_step(1)
set_global_step_op.run()
def testAccumulatorApplyGradFloat32(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
accum_op = q.apply_indexed_slices_grad(
@@ -123,7 +123,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertEqual(q.num_accumulated().eval(), 1)
def testDtypes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dtypes = [dtypes_lib.float16, dtypes_lib.float32, dtypes_lib.float64]
for i in range(len(dtypes)):
@@ -145,7 +145,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self._assertEqual_nparray(sum_elems / len(elems), result, sess)
def testAccumulatorMultipleAccumulators(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q_f32_0 = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
q_f32_1 = data_flow_ops.SparseConditionalAccumulator(
@@ -175,7 +175,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self._assertEqual_indexedslices(expected_tensors[i], result)
def testAccumulatorTakeGradMean(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -195,7 +195,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual([-1, 2], val.dense_shape)
def testAccumulatorTakeGradSum(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM")
@@ -220,7 +220,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid")
def testAccumulatorRepeatedTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=())
@@ -258,7 +258,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.dense_shape, [-1, 2])
def testParallelApplyGradMean(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
@@ -289,7 +289,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
val, sess)
def testParallelApplyGradSum(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -323,7 +323,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
val, sess)
def testParallelTakeGrad(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
elems = [e + 1 for e in range(10)]
@@ -362,7 +362,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
np.array([[0, 0], [elems[i], 0]]), results[i], sess)
def testAccumulatorApplyAndBlockingTake(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
@@ -397,7 +397,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
sess.run(takeg_op)
def testAccumulatorCancel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32,
name="Q",
@@ -416,7 +416,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
takeg_thread.join()
def testNonVectorIndices(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -428,7 +428,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_values=np.array([1, 2]).astype(np.float32)).run()
def testZeroDimensionValues(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -438,7 +438,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_indices=[0], grad_values=np.array(1).astype(np.float32)).run()
def testWrongNonEmptyInputValues(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -449,7 +449,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
grad_values=np.array([[0, 1, 1]]).astype(np.float32)).run()
def testDynamicNonVectorIndices(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -468,7 +468,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
})
def testDynamicWrongNonEmptyInputValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
@@ -486,7 +486,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
})
def testEmptyShapeApply(self):
- with self.test_session():
+ with self.cached_session():
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([]))
@@ -511,7 +511,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
q.apply_grad(grad_indices=[0], grad_values=[1.0]).run()
def testValidateShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=[2, 2, None])
@@ -606,7 +606,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
local_step=1).run()
def testReturnShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=[2, None])
@@ -631,7 +631,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase):
self.assertAllEqual(val.dense_shape, [-1, 2, 2, 3])
def testApplyGradtInt32IndicesAndShape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
q = data_flow_ops.SparseConditionalAccumulator(
dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([3, 3]))
accum_op = q.apply_grad(
diff --git a/tensorflow/python/kernel_tests/sparse_cross_op_test.py b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
index ca7898d466..6e0714da70 100644
--- a/tensorflow/python/kernel_tests/sparse_cross_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_cross_op_test.py
@@ -42,7 +42,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_dense(self):
@@ -62,7 +62,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_sparse(self):
@@ -76,7 +76,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2', '55555_X_batch2-FC2-F1',
'55555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_mixed_string_dense(self):
@@ -94,7 +94,7 @@ class SparseCrossOpTest(test.TestCase):
'55555_X_batch2-FC2-F1', '55555_X_batch2-FC2-F2',
'999999_X_batch2-FC2-F1', '999999_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_sparse_cross_dense(self):
@@ -111,7 +111,7 @@ class SparseCrossOpTest(test.TestCase):
'batch2-FC1-F1_X_batch2-FC2-F1', 'batch2-FC1-F1_X_batch2-FC2-F2',
'batch2-FC1-F2_X_batch2-FC2-F1', 'batch2-FC1-F2_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_integer_sparse_input(self):
@@ -127,7 +127,7 @@ class SparseCrossOpTest(test.TestCase):
'333_X_batch2-FC2-F1', '333_X_batch2-FC2-F2',
'5555_X_batch2-FC2-F1', '5555_X_batch2-FC2-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x3x3(self):
@@ -169,7 +169,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F2',
'batch1-FC1-F3_X_batch1-FC2-F3_X_batch1-FC3-F3'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_permutation_3x1x2(self):
@@ -188,7 +188,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F3_X_batch1-FC2-F1_X_batch1-FC3-F2'
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_large_batch(self):
@@ -221,7 +221,7 @@ class SparseCrossOpTest(test.TestCase):
])
expected_out = self._sparse_tensor(col_out)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_one_column_empty(self):
@@ -234,7 +234,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([], 1),
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_some_columns_empty(self):
@@ -253,7 +253,7 @@ class SparseCrossOpTest(test.TestCase):
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F1',
'batch1-FC1-F2_X_batch1-FC2-F1_X_batch1-FC3-F2'
]], 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_all_columns_empty(self):
@@ -266,7 +266,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([]),
self._sparse_tensor([])
])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_empty(sess.run(op))
def test_hashed_zero_bucket_no_hash_key(self):
@@ -277,7 +277,7 @@ class SparseCrossOpTest(test.TestCase):
])
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[1971693436396284976]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_zero_bucket(self):
@@ -290,7 +290,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[4847552627144134031]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
# TODO(sibyl-Aix6ihai): Add benchmark to compare Hashed vs Non-hashed.
@@ -304,7 +304,7 @@ class SparseCrossOpTest(test.TestCase):
num_buckets=100)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[83]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed_output(self):
@@ -318,7 +318,7 @@ class SparseCrossOpTest(test.TestCase):
hash_key=sparse_ops._DEFAULT_HASH_KEY + 1)
# Check actual hashed output to prevent unintentional hashing changes.
expected_out = self._sparse_tensor([[31]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._assert_sparse_tensor_equals(expected_out, sess.run(op))
def test_hashed__has_no_collision(self):
@@ -344,7 +344,7 @@ class SparseCrossOpTest(test.TestCase):
self._sparse_tensor([['batch1-FC3-F1', 'batch1-FC3-F2']])
],
num_buckets=1000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
out = sess.run(op)
self.assertEqual(6, len(out.values))
self.assertAllEqual([[0, i] for i in range(6)], out.indices)
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index f50e39d6d5..90009fc33e 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -130,7 +130,7 @@ class MatMulGradientTest(test.TestCase):
def _testGradients(self, tr_a, tr_b, sp_a, sp_b, a_dtype, b_dtype, delta,
name):
- with self.test_session():
+ with self.cached_session():
a = constant_op.constant(
RandMatrix(
3, 2, tr_a, round_bfloat=True), dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py
index fc39de150e..79efee3f5b 100644
--- a/tensorflow/python/kernel_tests/sparse_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_ops_test.py
@@ -628,7 +628,7 @@ class SparseReduceTest(test_util.TensorFlowTestCase):
else:
np_ans = np.max(np_ans, axis=ra, keepdims=keep_dims)
- with self.test_session():
+ with self.cached_session():
if do_sum:
tf_dense_ans = sparse_ops.sparse_reduce_sum(sp_t, reduction_axes,
keep_dims)
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 96793d5af3..31e84341ae 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -76,7 +76,7 @@ class SparseTensorsMapTest(test.TestCase):
return sparse_tensor_lib.SparseTensorValue(ind, val, shape)
def testAddTakeMany(self):
- with self.test_session(graph=ops.Graph(), use_gpu=False) as sess:
+ with self.session(graph=ops.Graph(), use_gpu=False) as sess:
sp_input0 = self._SparseTensorValue_5x6(np.arange(6))
sp_input1 = self._SparseTensorValue_3x4(np.arange(6))
handle0 = add_sparse_to_tensors_map(sp_input0, shared_name="a")
diff --git a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
index 87a4eb9c7b..c71746cc99 100644
--- a/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
+++ b/tensorflow/python/kernel_tests/sparse_to_dense_op_py_test.py
@@ -81,7 +81,7 @@ class SparseToDenseTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
def testZeroDefault(self):
- with self.test_session():
+ with self.cached_session():
x = sparse_ops.sparse_to_dense(2, [4], 7).eval()
self.assertAllEqual(x, [0, 0, 7, 0])
@@ -94,12 +94,12 @@ class SparseToDenseTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
def testBadShape(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1)
def testBadValue(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [[5], [3]], -1)
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[2,1\], "
@@ -107,20 +107,20 @@ class SparseToDenseTest(test.TestCase):
dense.eval()
def testBadNumValues(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [1, 2, 3], -1)
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
dense.eval()
def testBadDefault(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense([1, 3], [5], [1, 2], [0])
with self.assertRaisesOpError("default_value should be a scalar"):
dense.eval()
def testOutOfBoundsIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[1], [10]],
output_size=[5],
@@ -140,7 +140,7 @@ class SparseToDenseTest(test.TestCase):
dense_without_validation.eval()
def testRepeatingIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[1], [1]],
output_size=[5],
@@ -158,7 +158,7 @@ class SparseToDenseTest(test.TestCase):
dense_without_validation.eval()
def testUnsortedIndicesWithWithoutValidation(self):
- with self.test_session():
+ with self.cached_session():
dense = _SparseToDense(
sparse_indices=[[2], [1]],
output_size=[5],
diff --git a/tensorflow/python/kernel_tests/sparsemask_op_test.py b/tensorflow/python/kernel_tests/sparsemask_op_test.py
index cf6c9494ae..6f5dd45b61 100644
--- a/tensorflow/python/kernel_tests/sparsemask_op_test.py
+++ b/tensorflow/python/kernel_tests/sparsemask_op_test.py
@@ -34,7 +34,7 @@ class SparseMaskTest(test.TestCase):
out_values = values[1:, :]
out_indices = np.array([2, 3, 4], dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_tensor = ops.convert_to_tensor(values)
indices_tensor = ops.convert_to_tensor(indices)
mask_indices_tensor = ops.convert_to_tensor(mask_indices)
diff --git a/tensorflow/python/kernel_tests/string_format_op_test.py b/tensorflow/python/kernel_tests/string_format_op_test.py
new file mode 100644
index 0000000000..74a5072bab
--- /dev/null
+++ b/tensorflow/python/kernel_tests/string_format_op_test.py
@@ -0,0 +1,384 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.kernels.logging_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class StringFormatOpTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDim(self):
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", [tensor])
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableScalar(self):
+ with self.cached_session():
+ var = variables.Variable(3.34)
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "3.34"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneVariableOneDim(self):
+ with self.cached_session():
+ var = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}", [var])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatTwoVariablesWithAssignAdd(self):
+ with self.cached_session():
+ var_one = variables.Variable(2.14)
+ plus_one = var_one.assign_add(1.0)
+ var_two = variables.Variable(math_ops.range(10))
+ format_output = string_ops.string_format("{}, {}", [var_one, var_two])
+ if not context.executing_eagerly():
+ variables.global_variables_initializer().run()
+ self.evaluate(plus_one)
+ out = self.evaluate(format_output)
+ expected = "3.14, [0 1 2 ... 7 8 9]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimFloat(self):
+ with self.cached_session():
+ tensor = constant_op.constant([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = "[0 0.1 0.2 ... 0.5 0.6 0.7]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimMatchesSummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimVarySummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=-1)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=1)
+ out = self.evaluate(format_output)
+ expected = "[0 ... 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = "[0 1 ... 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ with self.cached_session():
+ tensor = math_ops.range(6)
+ format_output = string_ops.string_format("{}", tensor, summarize=10)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4 5]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorOneDimAlmostSummarize(self):
+ with self.cached_session():
+ tensor = math_ops.range(5)
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = "[0 1 2 3 4]"
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimLessThanSummarize(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(4), [2, 2])
+ format_output = string_ops.string_format("{}", tensor, summarize=3)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1]\n"
+ " [2 3]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDim(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTwoDimSummarizeTwo(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}", tensor, summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorThreeDim(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(1000), [10, 10, 10])
+ format_output = string_ops.string_format("{}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]\n"
+ "\n"
+ " [[100 101 102 ... 107 108 109]\n"
+ " [110 111 112 ... 117 118 119]\n"
+ " [120 121 122 ... 127 128 129]\n"
+ " ...\n [170 171 172 ... 177 178 179]\n"
+ " [180 181 182 ... 187 188 189]\n"
+ " [190 191 192 ... 197 198 199]]\n"
+ "\n"
+ " [[200 201 202 ... 207 208 209]\n"
+ " [210 211 212 ... 217 218 219]\n"
+ " [220 221 222 ... 227 228 229]\n"
+ " ...\n"
+ " [270 271 272 ... 277 278 279]\n"
+ " [280 281 282 ... 287 288 289]\n"
+ " [290 291 292 ... 297 298 299]]\n"
+ "\n"
+ " ...\n"
+ "\n"
+ " [[700 701 702 ... 707 708 709]\n"
+ " [710 711 712 ... 717 718 719]\n"
+ " [720 721 722 ... 727 728 729]\n"
+ " ...\n"
+ " [770 771 772 ... 777 778 779]\n"
+ " [780 781 782 ... 787 788 789]\n"
+ " [790 791 792 ... 797 798 799]]\n"
+ "\n"
+ " [[800 801 802 ... 807 808 809]\n"
+ " [810 811 812 ... 817 818 819]\n"
+ " [820 821 822 ... 827 828 829]\n"
+ " ...\n"
+ " [870 871 872 ... 877 878 879]\n"
+ " [880 881 882 ... 887 888 889]\n"
+ " [890 891 892 ... 897 898 899]]\n"
+ "\n"
+ " [[900 901 902 ... 907 908 909]\n"
+ " [910 911 912 ... 917 918 919]\n"
+ " [920 921 922 ... 927 928 929]\n"
+ " ...\n"
+ " [970 971 972 ... 977 978 979]\n"
+ " [980 981 982 ... 987 988 989]\n"
+ " [990 991 992 ... 997 998 999]]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplatePrefixAndSuffix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}, suffix",
+ tensor)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatOneTensorTemplateSuffix(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("{}, suffix", tensor)
+ out = self.evaluate(format_output)
+ expected = ("[[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], suffix")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatNoTensor(self):
+ with self.cached_session():
+ format_output = string_ops.string_format("No tensor.", ())
+ out = self.evaluate(format_output)
+ expected = "No tensor."
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatMultiTensor(self):
+ with self.cached_session():
+ tensor_one = array_ops.reshape(math_ops.range(100), [10, 10])
+ tensor_two = tensor_one * 10
+ format_output = string_ops.string_format("One: {},\nTwo: {}",
+ (tensor_one, tensor_two))
+ out = self.evaluate(format_output)
+ expected = ("One: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]],\n"
+ "Two: [[0 10 20 ... 70 80 90]\n"
+ " [100 110 120 ... 170 180 190]\n"
+ " [200 210 220 ... 270 280 290]\n"
+ " ...\n"
+ " [700 710 720 ... 770 780 790]\n"
+ " [800 810 820 ... 870 880 890]\n"
+ " [900 910 920 ... 970 980 990]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeOne(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=1)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 ... 9]\n"
+ " ...\n"
+ " [90 ... 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatSummarizeTwo(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: {}", tensor,
+ summarize=2)
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 ... 8 9]\n"
+ " [10 11 ... 18 19]\n"
+ " ...\n"
+ " [80 81 ... 88 89]\n"
+ " [90 91 ... 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFormatPlaceholder(self):
+ with self.cached_session():
+ tensor = array_ops.reshape(math_ops.range(100), [10, 10])
+ format_output = string_ops.string_format("tensor summary: %t%", tensor,
+ placeholder="%t%")
+ out = self.evaluate(format_output)
+ expected = ("tensor summary: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]]")
+ self.assertEqual(compat.as_text(out), expected)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testTensorCountMustMatchPlaceholderCount(self):
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", tensor)
+ self.evaluate(format_output)
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"2 placeholder\(s\) in template does not match 1 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{} {}", [tensor])
+ self.evaluate(format_output)
+ with self.cached_session():
+ with self.assertRaisesRegexp(
+ ValueError, r"1 placeholder\(s\) in template does not match 2 "
+ r"tensor\(s\) provided as input"):
+ tensor = math_ops.range(10)
+ format_output = string_ops.string_format("{}", (tensor, tensor))
+ self.evaluate(format_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/string_join_op_test.py b/tensorflow/python/kernel_tests/string_join_op_test.py
index ce19333654..e4371ab5b9 100644
--- a/tensorflow/python/kernel_tests/string_join_op_test.py
+++ b/tensorflow/python/kernel_tests/string_join_op_test.py
@@ -28,7 +28,7 @@ class StringJoinOpTest(test.TestCase):
input1 = "a"
input2 = [["b"], ["c"]]
- with self.test_session():
+ with self.cached_session():
output = string_ops.string_join([input0, input1])
self.assertAllEqual(output.eval(), [b"aa", b"ba"])
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 075a3204ad..4afe3ad3f4 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -27,11 +27,38 @@ class StringLengthOpTest(test.TestCase):
def testStringLength(self):
strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
lengths = string_ops.string_length(strings)
values = sess.run(lengths)
self.assertAllEqual(values, [[[1, 2], [3, 4], [5, 6]]])
+ def testUnit(self):
+ unicode_strings = [u"H\xc3llo", u"\U0001f604"]
+ utf8_strings = [s.encode("utf-8") for s in unicode_strings]
+ expected_utf8_byte_lengths = [6, 4]
+ expected_utf8_char_lengths = [5, 1]
+
+ with self.test_session() as sess:
+ utf8_byte_lengths = string_ops.string_length(utf8_strings, unit="BYTE")
+ utf8_char_lengths = string_ops.string_length(
+ utf8_strings, unit="UTF8_CHAR")
+ self.assertAllEqual(
+ sess.run(utf8_byte_lengths), expected_utf8_byte_lengths)
+ self.assertAllEqual(
+ sess.run(utf8_char_lengths), expected_utf8_char_lengths)
+ with self.assertRaisesRegexp(
+ ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' "
+ 'not in: "BYTE", "UTF8_CHAR"'):
+ string_ops.string_length(utf8_strings, unit="XYZ")
+
+ def testLegacyPositionalName(self):
+ # Code that predates the 'unit' parameter may have used a positional
+ # argument for the 'name' parameter. Check that we don't break such code.
+ strings = [[["1", "12"], ["123", "1234"], ["12345", "123456"]]]
+ lengths = string_ops.string_length(strings, "some_name")
+ with self.test_session():
+ self.assertAllEqual(lengths.eval(), [[[1, 2], [3, 4], [5, 6]]])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/string_split_op_test.py b/tensorflow/python/kernel_tests/string_split_op_test.py
index b6a0f45adc..b968e885ed 100644
--- a/tensorflow/python/kernel_tests/string_split_op_test.py
+++ b/tensorflow/python/kernel_tests/string_split_op_test.py
@@ -32,7 +32,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplit(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -42,7 +42,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitEmptyDelimiter(self):
strings = ["hello", "hola", b"\xF0\x9F\x98\x8E"] # Last string is U+1F60E
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, delimiter="")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
@@ -60,7 +60,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitEmptyToken(self):
strings = ["", " a", "b ", " c", " ", " d ", " e", "f ", " g ", " "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -72,7 +72,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitOnSetEmptyToken(self):
strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, delimiter=" .")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -84,7 +84,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimiter(self):
strings = ["hello|world", "hello world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertRaises(
ValueError, string_ops.string_split, strings, delimiter=["|", ""])
@@ -106,7 +106,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimiterTensor(self):
strings = ["hello|world", "hello world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
delimiter = array_ops.placeholder(dtypes.string)
tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -124,7 +124,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithDelimitersTensor(self):
strings = ["hello.cruel,world", "hello cruel world"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
delimiter = array_ops.placeholder(dtypes.string)
tokens = string_ops.string_split(strings, delimiter=delimiter)
@@ -143,7 +143,7 @@ class StringSplitOpTest(test.TestCase):
def testStringSplitWithNoSkipEmpty(self):
strings = ["#a", "b#", "#c#"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, "#", skip_empty=False)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -152,7 +152,7 @@ class StringSplitOpTest(test.TestCase):
self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
self.assertAllEqual(shape, [3, 3])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split(strings, "#")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(values, [b"a", b"b", b"c"])
@@ -165,7 +165,7 @@ class StringSplitV2OpTest(test.TestCase):
def testSplitV2(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
@@ -180,7 +180,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['', '', '4', '5', '', '6', '']
strings = ["1<>2<>3", "<><>4<>5<><>6<>"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep="<>")
indices, values, shape = sess.run(tokens)
self.assertAllEqual(
@@ -198,7 +198,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['1', '2', '', '3', '']
strings = ["1,2,3", "4,5,,6,"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',')
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -215,7 +215,7 @@ class StringSplitV2OpTest(test.TestCase):
#['1', '2', '3']
strings = ["1 2 3", " 4 5 6 "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2],
@@ -231,7 +231,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5,,6,']
strings = ["1,2,3", "4,5,,6,"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, sep=',', maxsplit=1)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
@@ -247,7 +247,7 @@ class StringSplitV2OpTest(test.TestCase):
# ['4', '5 6 ']
strings = ["1 2 3", " 4 5 6 "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tokens = string_ops.string_split_v2(strings, maxsplit=1)
indices, values, shape = sess.run(tokens)
self.assertAllEqual(indices, [[0, 0], [0, 1],
diff --git a/tensorflow/python/kernel_tests/string_strip_op_test.py b/tensorflow/python/kernel_tests/string_strip_op_test.py
index 30fd477ff4..a96b71490e 100644
--- a/tensorflow/python/kernel_tests/string_strip_op_test.py
+++ b/tensorflow/python/kernel_tests/string_strip_op_test.py
@@ -28,7 +28,7 @@ class StringStripOpTest(test.TestCase):
def test_string_strip(self):
strings = ["pigs on the wing", "animals"]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [b"pigs on the wing", b"animals"])
@@ -37,7 +37,7 @@ class StringStripOpTest(test.TestCase):
strings = [["pigs on the wing", "animals"],
[" hello ", "\n\tworld \r \n"]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [[b"pigs on the wing", b"animals"],
@@ -46,7 +46,7 @@ class StringStripOpTest(test.TestCase):
def test_string_strip_with_empty_strings(self):
strings = [" hello ", "", "world ", " \t \r \n "]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = string_ops.string_strip(strings)
output = sess.run(output)
self.assertAllEqual(output, [b"hello", b"", b"world", b""])
diff --git a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
index 2c6064e64b..9cb0c9d18f 100644
--- a/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_hash_bucket_op_test.py
@@ -27,7 +27,7 @@ from tensorflow.python.platform import test
class StringToHashBucketOpTest(test.TestCase):
def testStringToOneHashBucketFast(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket_fast(input_string, 1)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -35,7 +35,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([0, 0, 0], result)
def testStringToHashBucketsFast(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket_fast(input_string, 10)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c', 'd']})
@@ -47,7 +47,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([9, 2, 2, 5], result)
def testStringToOneHashBucketLegacyHash(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket(input_string, 1)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -55,7 +55,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([0, 0, 0], result)
def testStringToHashBucketsLegacyHash(self):
- with self.test_session():
+ with self.cached_session():
input_string = array_ops.placeholder(dtypes.string)
output = string_ops.string_to_hash_bucket(input_string, 10)
result = output.eval(feed_dict={input_string: ['a', 'b', 'c']})
@@ -66,14 +66,14 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([8, 0, 7], result)
def testStringToOneHashBucketStrongOneHashBucket(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
output = string_ops.string_to_hash_bucket_strong(
input_string, 1, key=[123, 345])
self.assertAllEqual([0, 0, 0], output.eval())
def testStringToHashBucketsStrong(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
output = string_ops.string_to_hash_bucket_strong(
input_string, 10, key=[98765, 132])
@@ -84,7 +84,7 @@ class StringToHashBucketOpTest(test.TestCase):
self.assertAllEqual([4, 2, 8], output.eval())
def testStringToHashBucketsStrongInvalidKey(self):
- with self.test_session():
+ with self.cached_session():
input_string = constant_op.constant(['a', 'b', 'c'])
with self.assertRaisesOpError('Key must have 2 elements'):
string_ops.string_to_hash_bucket_strong(
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
index cc4c21b66c..99ee25e125 100644
--- a/tensorflow/python/kernel_tests/string_to_number_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -29,7 +29,7 @@ _ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
class StringToNumberOpTest(test.TestCase):
def _test(self, tf_type, good_pairs, bad_pairs):
- with self.test_session():
+ with self.cached_session():
# Build a small testing graph.
input_string = array_ops.placeholder(dtypes.string)
output = parsing_ops.string_to_number(
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index 73ac71e1f5..cd3fe14883 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import errors_impl
@@ -25,7 +26,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class SubstrOpTest(test.TestCase):
+class SubstrOpTest(test.TestCase, parameterized.TestCase):
def _testScalarString(self, dtype):
test_string = b"Hello"
@@ -34,18 +35,40 @@ class SubstrOpTest(test.TestCase):
expected_value = b"ell"
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # position is equal to the length of string.
+ # Negative position.
+ test_string = b"Hello"
+ position = np.array(-4, dtype)
+ length = np.array(3, dtype)
+ expected_value = b"ell"
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
+ # Position is equal to the length of string.
test_string = b""
position = np.array(0, dtype)
length = np.array(2, dtype)
expected_value = b""
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
+ # Negative position magnitude is equal to the length of string.
+ test_string = b"yo"
+ position = np.array(-2, dtype)
+ length = np.array(1, dtype)
+ expected_value = b"y"
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -56,7 +79,18 @@ class SubstrOpTest(test.TestCase):
expected_value = [b"ell", b"orl"]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
+ # Negative position.
+ test_string = [b"Hello", b"World"]
+ position = np.array(-4, dtype)
+ length = np.array(3, dtype)
+ expected_value = [b"ell", b"orl"]
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -70,7 +104,21 @@ class SubstrOpTest(test.TestCase):
[b"ixte", b"even", b"ight"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ substr = substr_op.eval()
+ self.assertAllEqual(substr, expected_value)
+
+ # Negative position
+ test_string = [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]]
+ position = np.array(-2, dtype)
+ length = np.array(2, dtype)
+ expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
+ [b"en", b"en", b"en"]]
+
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -78,13 +126,13 @@ class SubstrOpTest(test.TestCase):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
- length = np.array([[2, 3, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"eve", b"lve"], [b"hirt", b"urt", b"te"],
- [b"ixtee", b"vente", b"hteen"]]
+ position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
+ length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
+ expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -94,33 +142,33 @@ class SubstrOpTest(test.TestCase):
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"],
[b"nineteen", b"twenty", b"twentyone"]]
- position = np.array([1, 2, 3], dtype)
+ position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
- [b"i", b"ve", b"hte"], [b"i", b"en", b"nty"]]
+ expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
test_string = [b"thirteen", b"fourteen", b"fifteen"]
- position = np.array([[1, 2, 3], [3, 2, 1], [5, 5, 5]], dtype)
+ position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"ur", b"t"], [b"r", b"ur", b"ift"],
- [b"ee", b"ee", b"en"]]
+ expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
test_string = b"thirteen"
- position = np.array([1, 5, 7], dtype)
+ position = np.array([1, -5, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"ee", b"n"]
+ expected_value = [b"hir", b"rt", b"n"]
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
@@ -128,10 +176,8 @@ class SubstrOpTest(test.TestCase):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([1, 2, 3, 4], dtype)
+ position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"ur", b"tee"],
- [b"i", b"ve", b"hte"]]
with self.assertRaises(ValueError):
substr_op = string_ops.substr(test_string, position, length)
@@ -141,7 +187,16 @@ class SubstrOpTest(test.TestCase):
position = np.array(7, dtype)
length = np.array(3, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Scalar/Scalar (with negative)
+ test_string = b"Hello"
+ position = np.array(-7, dtype)
+ length = np.array(3, dtype)
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -150,16 +205,16 @@ class SubstrOpTest(test.TestCase):
position = np.array(4, dtype)
length = np.array(1, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
- # Negative pos
- test_string = b"Hello"
- position = np.array(-1, dtype)
- length = np.array(3, dtype)
+ # Vector/Scalar (with negative)
+ test_string = [b"good", b"good", b"bad", b"good"]
+ position = np.array(-4, dtype)
+ length = np.array(1, dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -169,7 +224,17 @@ class SubstrOpTest(test.TestCase):
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Matrix/Matrix (with negative)
+ test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]]
+ position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
+ length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -178,7 +243,16 @@ class SubstrOpTest(test.TestCase):
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
substr_op = string_ops.substr(test_string, position, length)
- with self.test_session():
+ with self.cached_session():
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ substr = substr_op.eval()
+
+ # Broadcast (with negative)
+ test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ position = np.array([-1, -2, -4], dtype)
+ length = np.array([1, 2, 3], dtype)
+ substr_op = string_ops.substr(test_string, position, length)
+ with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
substr = substr_op.eval()
@@ -198,7 +272,18 @@ class SubstrOpTest(test.TestCase):
with self.assertRaises(ValueError):
substr_op = string_ops.substr(test_string, position, length)
- def _testAll(self, dtype):
+ # Negative position.
+ test_string = [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]]
+ position = np.array([[-1, -2, -3]], dtype)
+ length = np.array([1, 2, 3], dtype)
+ # Should fail: position/length have different rank
+ with self.assertRaises(ValueError):
+ substr_op = string_ops.substr(test_string, position, length)
+
+ @parameterized.parameters(np.int32, np.int64)
+ def testAll(self, dtype):
self._testScalarString(dtype)
self._testVectorStrings(dtype)
self._testMatrixStrings(dtype)
@@ -208,14 +293,8 @@ class SubstrOpTest(test.TestCase):
self._testOutOfRangeError(dtype)
self._testMismatchPosLenShapes(dtype)
- def testInt32(self):
- self._testAll(np.int32)
-
- def testInt64(self):
- self._testAll(np.int64)
-
def testWrongDtype(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3.0, 1)
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/kernel_tests/summary_audio_op_test.py b/tensorflow/python/kernel_tests/summary_audio_op_test.py
index eaae671192..e59a2ceef7 100644
--- a/tensorflow/python/kernel_tests/summary_audio_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_audio_op_test.py
@@ -50,7 +50,7 @@ class SummaryAudioOpTest(test.TestCase):
def testAudioSummary(self):
np.random.seed(7)
for channels in (1, 2, 5, 8):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
num_frames = 7
shape = (4, num_frames, channels)
# Generate random audio in the range [-1.0, 1.0).
diff --git a/tensorflow/python/kernel_tests/summary_image_op_test.py b/tensorflow/python/kernel_tests/summary_image_op_test.py
index 4718827e88..b650e10404 100644
--- a/tensorflow/python/kernel_tests/summary_image_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_image_op_test.py
@@ -52,7 +52,7 @@ class SummaryImageOpTest(test.TestCase):
def testImageSummary(self):
for depth in (1, 3, 4):
for positive in False, True:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
shape = (4, 5, 7) + (depth,)
bad_color = [255, 0, 0, 255][:depth]
# Build a mostly random image with one nan
@@ -87,7 +87,7 @@ class SummaryImageOpTest(test.TestCase):
def testImageSummaryUint8(self):
np.random.seed(7)
for depth in (1, 3, 4):
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
shape = (4, 5, 7) + (depth,)
# Build a random uint8 image
diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py
index 2da7107f61..0c500120b0 100644
--- a/tensorflow/python/kernel_tests/summary_ops_test.py
+++ b/tensorflow/python/kernel_tests/summary_ops_test.py
@@ -34,7 +34,7 @@ class SummaryOpsTest(test.TestCase):
return summ
def testScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant([10.0, 20.0])
summ = logging_ops.scalar_summary(["c1", "c2"], const, name="mysumm")
value = sess.run(summ)
@@ -45,7 +45,7 @@ class SummaryOpsTest(test.TestCase):
""", self._AsSummary(value))
def testScalarSummaryDefaultName(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant([10.0, 20.0])
summ = logging_ops.scalar_summary(["c1", "c2"], const)
value = sess.run(summ)
@@ -56,7 +56,7 @@ class SummaryOpsTest(test.TestCase):
""", self._AsSummary(value))
def testMergeSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(10.0)
summ1 = summary.histogram("h", const)
summ2 = logging_ops.scalar_summary("c", const)
diff --git a/tensorflow/python/kernel_tests/summary_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
index d534aadb79..0f4643393a 100644
--- a/tensorflow/python/kernel_tests/summary_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_tensor_op_test.py
@@ -42,7 +42,7 @@ class SummaryOpsTest(test.TestCase):
self.assertTrue(np.array_equal(actual, expected))
def testTags(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c = constant_op.constant(1)
s1 = summary_ops.tensor_summary("s1", c)
with ops.name_scope("foo"):
@@ -65,7 +65,7 @@ class SummaryOpsTest(test.TestCase):
self.assertEqual(v4.tag, "foo/zod/TensorSummary")
def testScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(10.0)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -76,7 +76,7 @@ class SummaryOpsTest(test.TestCase):
def testStringSummary(self):
s = six.b("foobar")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(s)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -86,7 +86,7 @@ class SummaryOpsTest(test.TestCase):
self._AssertNumpyEq(n, s)
def testManyScalarSummary(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = array_ops.ones([5, 5, 5])
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -96,7 +96,7 @@ class SummaryOpsTest(test.TestCase):
def testManyStringSummary(self):
strings = [[six.b("foo bar"), six.b("baz")], [six.b("zoink"), six.b("zod")]]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(strings)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -106,7 +106,7 @@ class SummaryOpsTest(test.TestCase):
def testManyBools(self):
bools = [True, True, True, False, False, False]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
const = constant_op.constant(bools)
summ = summary_ops.tensor_summary("foo", const)
result = sess.run(summ)
@@ -116,7 +116,7 @@ class SummaryOpsTest(test.TestCase):
self._AssertNumpyEq(n, bools)
def testSummaryDescriptionAndDisplayName(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def get_description(summary_op):
summ_str = sess.run(summary_op)
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 6de6fbe767..0ad2063558 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1504,6 +1504,19 @@ class TensorArrayTest(test.TestCase):
vdx, vdy = sess.run([dx, dy])
self.assertAllClose(vdx, vdy)
+ def testTensorArrayInt64GPU(self):
+ if not test.is_gpu_available():
+ return
+ with self.test_session(use_gpu=True, force_gpu=True) as sess:
+ value = array_ops.placeholder(dtypes.int64)
+ ta = tensor_array_ops.TensorArray(dtype=dtypes.int64, size=2)
+ ta = ta.scatter([0, 1], value)
+ r0 = ta.read(0)
+ r1 = ta.read(1)
+ v0, v1 = sess.run([r0, r1], feed_dict={value: [-3, 100]})
+ self.assertAllEqual(v0, -3)
+ self.assertAllEqual(v1, 100)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index 8ad29afd0a..d8d76440f1 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -48,7 +48,7 @@ class TensordotTest(test_lib.TestCase):
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, (a_axes, b_axes))
# Invalid dynamic shapes.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Matrix size-incompatible"):
a_ph = array_ops.placeholder(dtypes.float32)
@@ -80,7 +80,7 @@ class TensordotTest(test_lib.TestCase):
output = math_ops.tensordot(a_ph, b_ph, axes_ph)
# Note: We don't support scalar Tensor values for axes.
for axes_value in 1, [1], [0, 1], [[1]], [[0, 1]], [[0], [7]]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
_ = sess.run(
[output], feed_dict={
@@ -92,7 +92,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
for axes_value in [1, 2], [[1], [2]], [[], []], 0:
- with self.test_session() as sess:
+ with self.cached_session():
np_a = np.ones((3, 3))
np_b = np.array([2, 3, 1])[None, None]
np_ans = np.tensordot(np_a, np_b, axes_value)
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index 290200ce45..f42800226e 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -451,13 +451,13 @@ class TransposeTest(test.TestCase):
array_ops.transpose(array_ops.placeholder(dtypes.int32)).get_shape())
def testNullTensor(self):
- with self.test_session():
+ with self.cached_session():
x = constant_op.constant([], dtype=dtypes.float32, shape=[1, 4, 0])
xt = array_ops.transpose(x, [0, 2, 1]).eval()
self.assertAllEqual(xt.shape, (1, 0, 4))
def _testError(self, x, p, err):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError(err):
array_ops.transpose(x, p).eval()
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index bbc040dc13..316570e13e 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -30,7 +30,7 @@ class UniqueTest(test.TestCase):
def testInt32(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x)
tf_y, tf_idx = sess.run([y, idx])
@@ -41,7 +41,7 @@ class UniqueTest(test.TestCase):
def testInt32OutIdxInt64(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x, out_idx=dtypes.int64)
tf_y, tf_idx = sess.run([y, idx])
@@ -53,7 +53,7 @@ class UniqueTest(test.TestCase):
def testString(self):
indx = np.random.randint(65, high=122, size=7000)
x = [chr(i) for i in indx]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = array_ops.unique(x)
tf_y, tf_idx = sess.run([y, idx])
@@ -65,7 +65,7 @@ class UniqueTest(test.TestCase):
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
tf_y0, tf_idx0 = sess.run([y0, idx0])
y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
@@ -79,7 +79,7 @@ class UniqueTest(test.TestCase):
# This test is only temporary, once V2 is used
# by default, the axis will be wrapped to allow `axis=None`.
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32))
tf_y, tf_idx = sess.run([y, idx])
@@ -93,7 +93,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -106,7 +106,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32OutIdxInt64(self):
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -121,7 +121,7 @@ class UniqueWithCountsTest(test.TestCase):
indx = np.random.randint(65, high=122, size=7000)
x = [chr(i) for i in indx]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
@@ -136,7 +136,7 @@ class UniqueWithCountsTest(test.TestCase):
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([0], dtype))
tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
@@ -154,7 +154,7 @@ class UniqueWithCountsTest(test.TestCase):
# This test is only temporary, once V2 is used
# by default, the axis will be wrapped to allow `axis=None`.
x = np.random.randint(2, high=10, size=7000)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y, idx, count = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([], np.int32))
tf_y, tf_idx, tf_count = sess.run([y, idx, count])
diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 1ee6e0866a..b373c419b6 100644
--- a/tensorflow/python/kernel_tests/unstack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -99,7 +99,7 @@ class UnstackOpTest(test.TestCase):
self.assertLess(err, 1e-6)
def testInferNum(self):
- with self.test_session():
+ with self.cached_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
x = array_ops.placeholder(np.float32, shape=shape)
cs = array_ops.unstack(x)
@@ -131,13 +131,13 @@ class UnstackOpTest(test.TestCase):
for j in range(-i, i):
expected = np_split_squeeze(a, j)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_unstack = sess.run(array_ops.unstack(a, axis=j))
self.assertAllEqual(expected, actual_unstack)
def testAxis0Default(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
unstacked = sess.run(array_ops.unstack(a))
@@ -156,7 +156,7 @@ class UnstackOpTest(test.TestCase):
array_ops.unstack(a, axis=-3)
def testZeroLengthDim(self):
- with self.test_session():
+ with self.cached_session():
x = array_ops.zeros(shape=(0, 1, 2))
y = array_ops.unstack(x, axis=1)[0].eval()
self.assertEqual(y.shape, (0, 2))
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
index cf369c0718..3d2f8b6155 100644
--- a/tensorflow/python/kernel_tests/variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -118,7 +118,7 @@ class VariableOpTest(test.TestCase):
self.assertEqual(tensor_shape.unknown_shape(), assigned.get_shape())
def testAssignNoShape(self):
- with self.test_session():
+ with self.cached_session():
value = self._NewShapelessTensor()
var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
@@ -126,7 +126,7 @@ class VariableOpTest(test.TestCase):
state_ops.assign(var, value).get_shape())
def testAssignNoShapeNoValidateShape(self):
- with self.test_session():
+ with self.cached_session():
value = self._NewShapelessTensor()
var = state_ops.variable_op([1, 2], dtypes.float32, set_shape=False)
self.assertEqual(tensor_shape.unknown_shape(), var.get_shape())
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index d57b79cb90..401e1ae102 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -113,7 +113,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(w.constraint, constraint)
def testStringDefaultInitializer(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable("string", shape=[], dtype=dtypes.string)
variables_lib.global_variables_initializer().run()
self.assertAllEqual(compat.as_bytes(v.eval()), b"")
@@ -263,7 +263,7 @@ class VariableScopeTest(test.TestCase):
# TODO(alive): support variable partitioning/caching in eager mode.
def testVarScopeCachingDevice(self):
- with self.test_session():
+ with self.cached_session():
caching_device = "/job:moo"
with variable_scope.variable_scope("tower"):
with variable_scope.variable_scope(
@@ -367,7 +367,7 @@ class VariableScopeTest(test.TestCase):
variable_scope.get_variable("s", initializer=init, dtype=dtypes.float64)
def testControlDeps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variable_scope.get_variable(
"v0", [1], initializer=init_ops.constant_initializer(0))
with ops.control_dependencies([v0.value()]):
@@ -403,7 +403,7 @@ class VariableScopeTest(test.TestCase):
variable_scope._DEFAULT_USE_RESOURCE = old
def testControlFlow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variable_scope.get_variable(
"v0", [], initializer=init_ops.constant_initializer(0))
var_dict = {}
@@ -513,7 +513,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "testVarScopeNameScope3/scope2/")
def testVarScopeOriginalNameScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("scope1"):
with variable_scope.variable_scope("tower") as tower:
self.assertEqual(tower.original_name_scope, "scope1/tower/")
@@ -536,7 +536,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc3, "scope1/tower/bar_1/")
def testVarScopeObjectReuse(self):
- with self.test_session():
+ with self.cached_session():
vs = None
with variable_scope.variable_scope("jump", reuse=True) as scope:
vs = scope
@@ -563,7 +563,7 @@ class VariableScopeTest(test.TestCase):
self.assertFalse(jump_no_reuse.reuse)
def testVarScopeGetOrCreateReuse(self):
- with self.test_session():
+ with self.cached_session():
def test_value(value):
x = constant_op.constant(value)
@@ -582,7 +582,7 @@ class VariableScopeTest(test.TestCase):
test_value(17.)
def testVarOpScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("testVarOpScope1"):
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -608,7 +608,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "testVarOpScope2/default_1/testVarOpScope2/")
def testVarOpScopeUniqueNamesInterleavedSubstringScopes(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(None, "defaultScope1"):
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
@@ -631,7 +631,7 @@ class VariableScopeTest(test.TestCase):
"defaultScope1_2/layer/w:0")
def testVarOpScopeUniqueNamesWithJump(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("default") as default:
with variable_scope.variable_scope(None, "layer"):
self.assertEqual(
@@ -647,7 +647,7 @@ class VariableScopeTest(test.TestCase):
variable_scope.get_variable("w", []).name, "default/layer_2/w:0")
def testVarOpScopeReuse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -673,7 +673,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testVarScopeGetVar(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("root"):
with variable_scope.variable_scope("towerA") as tower_a:
va = variable_scope.get_variable("v", [1])
@@ -719,7 +719,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual("dtype" in str(exc.exception), True)
def testVarScopeOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
pass
with variable_scope.variable_scope(outer):
@@ -743,7 +743,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_2/default/scope2/")
def testVarScopeNestedOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope(outer):
self.assertEqual(
@@ -768,7 +768,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer/default_1/scope2/")
def testVarOpScopeReuseParam(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope("tower", "default", []):
self.assertEqual(
@@ -795,14 +795,14 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testVarOpScopeReuseError(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
with variable_scope.variable_scope(None, "default", reuse=True):
self.assertEqual(
variable_scope.get_variable("w", []).name, "outer/tower/w:0")
def testVarOpScopeOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
pass
with variable_scope.variable_scope(outer, "default", []):
@@ -827,7 +827,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_2/default/scope2/")
def testVarOpScopeNestedOuterScope(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer") as outer:
with variable_scope.variable_scope(outer, "default", []):
self.assertEqual(
@@ -851,7 +851,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(sc2, "outer_1/default/scope2/")
def testBasicWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"scope", auxiliary_name_scope=False) as scope:
self.assertEqual(scope.original_name_scope, "")
@@ -886,7 +886,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c").name, "outer/inner/c:0")
def testCreatedByDefaultNameWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
None, default_name="default", auxiliary_name_scope=False) as scope:
self.assertEqual(scope.original_name_scope, "")
@@ -910,7 +910,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c").name, "outer/default/c:0")
def testReenterRootScopeWhenAuxiliaryNameScopeIsFalse(self):
- with self.test_session():
+ with self.cached_session():
root_scope = variable_scope.get_variable_scope()
with variable_scope.variable_scope(
root_scope, auxiliary_name_scope=False) as scope:
@@ -927,7 +927,7 @@ class VariableScopeTest(test.TestCase):
constant_op.constant([], name="c1").name, "outer/c1:0")
def testAuxiliaryNameScopeIsInvalid(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "auxiliary_name_scope"):
with variable_scope.variable_scope(
None, default_name="scope", auxiliary_name_scope="invalid"):
@@ -947,7 +947,7 @@ class VariableScopeTest(test.TestCase):
def testReuseScopeWithoutNameScopeCollision(self):
# Github issue: #13429
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope("outer"):
with variable_scope.variable_scope("inner") as inner:
pass
@@ -1021,7 +1021,7 @@ class VariableScopeTest(test.TestCase):
self.assertEqual(varname_type[1], ("y", dtypes.int64))
def testGetCollection(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetCollection_a", [])
_ = variable_scope.get_variable(
"testGetCollection_b", [], trainable=False)
@@ -1075,7 +1075,7 @@ class VariableScopeTest(test.TestCase):
])
def testGetTrainableVariablesWithGetVariable(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetTrainableVariables_a", [])
with variable_scope.variable_scope(
"testGetTrainableVariables_foo") as scope:
@@ -1111,7 +1111,7 @@ class VariableScopeTest(test.TestCase):
trainable=True)
def testGetTrainableVariablesWithVariable(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
with variable_scope.variable_scope(
"testGetTrainableVariables_foo") as scope:
@@ -1150,7 +1150,7 @@ class VariableScopeTest(test.TestCase):
trainable=True)
def testGetGlobalVariables(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable("testGetGlobalVariables_a", [])
with variable_scope.variable_scope("testGetGlobalVariables_foo") as scope:
_ = variable_scope.get_variable("testGetGlobalVariables_b", [])
@@ -1160,7 +1160,7 @@ class VariableScopeTest(test.TestCase):
"testGetGlobalVariables_b:0"])
def testGetLocalVariables(self):
- with self.test_session():
+ with self.cached_session():
_ = variable_scope.get_variable(
"a", [], collections=[ops.GraphKeys.LOCAL_VARIABLES])
with variable_scope.variable_scope("foo") as scope:
@@ -1396,7 +1396,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual("scope/v/0:0", true_vars[0].name)
self.assertEqual("scope/v/1:0", true_vars[1].name)
self.assertEqual("custom_getter/add:0", v.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
np_vars, np_v = sess.run([true_vars, v])
self.assertAllClose(np_v, sum(np_vars))
@@ -1436,7 +1436,7 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(template % (1, 1, 0), true_vars[6].name)
self.assertEqual(template % (1, 1, 1), true_vars[7].name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
np_vars, np_v = sess.run([true_vars, v])
# take products of sums of products
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2b9c62ad6f..2e7975667c 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -42,7 +42,7 @@ from tensorflow.python.util import compat
class VariablesTestCase(test.TestCase):
def testInitialization(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
@@ -69,7 +69,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(1.1, var1.eval())
def testInitializationOrder(self):
- with self.test_session():
+ with self.cached_session():
rnd = variables.Variable(random_ops.random_uniform([3, 6]), name="rnd")
self.assertEqual("rnd:0", rnd.name)
self.assertEqual([3, 6], rnd.get_shape())
@@ -106,7 +106,7 @@ class VariablesTestCase(test.TestCase):
pass
def testAssignments(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(0.0)
plus_one = var.assign_add(1.0)
minus_one = var.assign_sub(2.0)
@@ -142,7 +142,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(4.0, var.eval())
def testZeroSizeStringAssign(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
array = variables.Variable(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
@@ -154,7 +154,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], list(sess.run(copy_op)))
def _countUpToTest(self, dtype):
- with self.test_session():
+ with self.cached_session():
zero = constant_op.constant(0, dtype=dtype)
var = variables.Variable(zero)
count_up_to = var.count_up_to(3)
@@ -186,7 +186,7 @@ class VariablesTestCase(test.TestCase):
self._countUpToTest(dtypes.int64)
def testControlDepsNone(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1.0)
with ops.control_dependencies([c]):
# d get the control dep.
@@ -199,7 +199,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([], var_x._ref().op.control_inputs) # pylint: disable=protected-access
def testControlFlow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
v0 = variables.Variable(0, name="v0")
var_dict = {}
@@ -248,7 +248,7 @@ class VariablesTestCase(test.TestCase):
control_flow_ops.while_loop(cond, body, [0, 0])
def testUseVariableAsTensor(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(3.0)
variables.global_variables_initializer().run()
@@ -257,7 +257,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(5.0, math_ops.add(var_x, var_y).eval())
def testZeroSizeVarSameAsConst(self):
- with self.test_session():
+ with self.cached_session():
zero_size_var = variables.Variable(array_ops.zeros([0, 2]))
zero_size_const = array_ops.ones([2, 0])
variable_mul = math_ops.matmul(zero_size_const, zero_size_var)
@@ -269,7 +269,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[0., 0.], [0., 0.]], variable_output)
def testCachingDevice(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(2.0)
self.assertEqual(var.device, var.value().device)
self.assertEqual(var.device, var.initialized_value().device)
@@ -279,7 +279,7 @@ class VariablesTestCase(test.TestCase):
self.assertTrue(var_cached.value().device.startswith("/job:foo"))
def testCollections(self):
- with self.test_session():
+ with self.cached_session():
var_x = variables.Variable(2.0)
var_y = variables.Variable(2.0, trainable=False)
var_z = variables.Variable(2.0, trainable=True)
@@ -294,7 +294,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_x, var_z, var_t], variables.trainable_variables())
def testCollectionsWithScope(self):
- with self.test_session():
+ with self.cached_session():
with ops.name_scope("scope_1"):
var_x = variables.Variable(2.0)
with ops.name_scope("scope_2"):
@@ -309,7 +309,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual([var_y], variables.trainable_variables("scope_2"))
def testOperators(self):
- with self.test_session():
+ with self.cached_session():
var_f = variables.Variable([2.0])
add = var_f + 0.0
radd = 1.0 + var_f
@@ -382,13 +382,13 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose([[20.0, 30.0], [40.0, 60.0]], rmatmul.eval())
def testSession(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1, 12])
variables.global_variables_initializer().run()
self.assertAllClose([1, 12], sess.run(var))
def testDevicePlacement(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with ops.device("/cpu:0"):
var = variables.Variable([1, 12])
init_value = var.initialized_value()
@@ -408,7 +408,7 @@ class VariablesTestCase(test.TestCase):
def testInitializerFunction(self):
value = [[-42], [133.7]]
shape = [2, 1]
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(value)
v1 = variables.Variable(initializer, dtype=dtypes.float32)
@@ -443,7 +443,7 @@ class VariablesTestCase(test.TestCase):
constraint=constraint)
def testNoRefDataRace(self):
- with self.test_session():
+ with self.cached_session():
a = variables.Variable([1, 2, 3], dtype=dtypes.float32)
b = variables.Variable(a.initialized_value() + 2)
c = variables.Variable(b.initialized_value() + 2)
@@ -453,7 +453,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllEqual(c.eval(), [5, 6, 7])
def testInitializerFunctionDevicePlacement(self):
- with self.test_session():
+ with self.cached_session():
initializer = lambda: constant_op.constant(42.0)
with ops.device("/cpu:100"):
v1 = variables.Variable(initializer, dtype=dtypes.float32, name="v1")
@@ -471,11 +471,11 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(expected_group_v2, i.op.colocation_groups())
def testVariableDefInitializedInstances(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v_def = variables.Variable(
initial_value=constant_op.constant(3.0)).to_proto()
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = variables.Variable(variable_def=v_def)
self.assertEqual(3.0, sess.run(v.initialized_value()))
@@ -486,7 +486,7 @@ class VariablesTestCase(test.TestCase):
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
# Restoring a legacy VariableDef proto that does not have
# initial_value_name set should still work.
v = variables.Variable(variable_def=v_def)
@@ -514,7 +514,7 @@ class VariablesTestCase(test.TestCase):
.trainable)
def testLoad(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
variables.global_variables_initializer().run()
var.load(np.ones((5, 5), np.float32))
@@ -540,12 +540,12 @@ class VariablesTestCase(test.TestCase):
class IsInitializedTest(test.TestCase):
def testNoVars(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
uninited = variables.report_uninitialized_variables()
self.assertEqual(0, sess.run(uninited).size)
def testAssertVariablesInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
_ = v, w
@@ -555,7 +555,7 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2], name="v")
w = variables.Variable([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
@@ -566,14 +566,14 @@ class IsInitializedTest(test.TestCase):
self.assertEqual(0, sess.run(uninited).size)
def testZeroSizeVarInitialized(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable(array_ops.zeros([0, 2]), name="v")
uninited = variables.report_uninitialized_variables()
v.initializer.run() # not strictly necessary
self.assertEqual(0, sess.run(uninited).size)
def testTrainingWithZeroSizeVar(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
a = variables.Variable(array_ops.zeros([0, 2]))
b = variables.Variable(array_ops.ones([2, 2]))
objective = math_ops.reduce_sum(b + math_ops.matmul(
@@ -592,7 +592,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
self.assertEqual(None, variables.assert_variables_initialized())
def testVariables(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
_ = v, w
@@ -603,7 +603,7 @@ class ObsoleteIsInitializedTest(test.TestCase):
sess.run(inited)
def testVariableList(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.Variable([1, 2])
w = variables.Variable([3, 4])
inited = variables.assert_variables_initialized([v])
diff --git a/tensorflow/python/kernel_tests/weights_broadcast_test.py b/tensorflow/python/kernel_tests/weights_broadcast_test.py
index eda2856e0b..85f9abc69f 100644
--- a/tensorflow/python/kernel_tests/weights_broadcast_test.py
+++ b/tensorflow/python/kernel_tests/weights_broadcast_test.py
@@ -44,7 +44,7 @@ class AssertBroadcastableTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.assert_broadcastable(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
static_op.run()
dynamic_op.run(feed_dict={
weights_placeholder: weights,
@@ -100,7 +100,7 @@ class AssertBroadcastableTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.assert_broadcastable(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
dynamic_op.run(feed_dict={
weights_placeholder: weights,
@@ -157,7 +157,7 @@ class BroadcastWeightsTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.broadcast_weights(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected, static_op.eval())
self.assertAllEqual(expected, dynamic_op.eval(feed_dict={
weights_placeholder: weights,
@@ -227,7 +227,7 @@ class BroadcastWeightsTest(test.TestCase):
values_placeholder = array_ops.placeholder(dtypes_lib.float32)
dynamic_op = weights_broadcast_ops.broadcast_weights(
weights=weights_placeholder, values=values_placeholder)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(errors_impl.OpError, error_msg):
dynamic_op.eval(feed_dict={
weights_placeholder: weights,
diff --git a/tensorflow/python/kernel_tests/while_v2_test.py b/tensorflow/python/kernel_tests/while_v2_test.py
new file mode 100644
index 0000000000..3a070544e8
--- /dev/null
+++ b/tensorflow/python/kernel_tests/while_v2_test.py
@@ -0,0 +1,276 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for while_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import while_v2
+from tensorflow.python.ops.control_flow_ops import while_loop as while_loop_v1
+from tensorflow.python.ops.while_v2 import while_loop as while_loop_v2
+from tensorflow.python.platform import test
+
+
+class WhileV2Test(test.TestCase, parameterized.TestCase):
+
+ def testSingleLoopVar(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * v, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testMultipleLoopVarsBasic(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, w), [x, y])
+ # ret = [x*y^2, y]
+
+ # Note: This is simply d_ret[0]/d_x since d_ret[1]/d_x is 0.
+ grad = gradients_impl.gradients(ret, [x]) # [2*x*y]
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [45., 3.])
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testMultipleLoopVars(self):
+ x = constant_op.constant(5.)
+ y = constant_op.constant(3.)
+
+ # x = 5.
+ # y = 3.
+ # while x < 45.:
+ # x = x * y
+ # y = x + y
+ ret = while_loop_v2(lambda v, _: v < 45., lambda v, w: (v * w, v + w),
+ [x, y])
+ # ret = [y*x**2 + x*y**2, x*y + x + y]
+
+ gradx_0 = gradients_impl.gradients(ret[0], [x]) # [2*x*y + y**2]
+ gradx_1 = gradients_impl.gradients(ret[1], [x]) # [y + 1]
+ gradx_2 = gradients_impl.gradients(ret, [x]) # [2*x*y + y**2 + 2*y + 1]
+ grady_0 = gradients_impl.gradients(ret[0], [y]) # [2*x*y + x**2]
+ grady_1 = gradients_impl.gradients(ret[1], [y]) # [x + 1]
+ grady_2 = gradients_impl.gradients(ret, [y]) # [2*x*y + x**2 + x + 1]
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(ret), [120., 23.])
+ self.assertSequenceEqual(sess.run(gradx_0), [39.])
+ self.assertSequenceEqual(sess.run(gradx_1), [4.])
+ self.assertSequenceEqual(sess.run(gradx_2), [43.])
+ self.assertSequenceEqual(sess.run(grady_0), [55.])
+ self.assertSequenceEqual(sess.run(grady_1), [6.])
+ self.assertSequenceEqual(sess.run(grady_2), [61.])
+
+ def testMultipleWhileLoops(self):
+ x = constant_op.constant(2.)
+ ret1 = while_loop_v2(lambda v: v < 4., lambda v: v * v, [x]) # x**2
+ ret2 = while_loop_v2(lambda v: v < 16., lambda v: v * v, ret1) # x**4
+ grad = gradients_impl.gradients(ret2, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.cached_session() as sess:
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testDoubleDerivative(self):
+ x = constant_op.constant(2.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v**2, [x]) # x**4
+ grad = gradients_impl.gradients(ret, [x]) # 4x**3
+ grad_grad = gradients_impl.gradients(grad, [x]) # 12x**2
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+ self.assertSequenceEqual(sess.run(grad_grad), [48.])
+
+ def testPruning(self):
+ x = constant_op.constant(1)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=x.dtype, element_shape=x.shape)
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5
+
+ def Body(x, tl):
+ return x + 1, list_ops.tensor_list_push_back(tl, x)
+
+ outputs = while_loop_v1(Cond, Body, [x, tensor_list])
+
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ train_op.append(outputs[0])
+
+ def GetOptimizedGraph():
+ mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ return tf_optimizer.OptimizeGraph(rewriter_config, mg)
+
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 1)
+
+ stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
+ train_op.append(stack)
+ g = GetOptimizedGraph()
+ self.assertEqual(len([n for n in g.node if n.op == "Enter"]), 2)
+
+ def testCaptureExternalTensorInCond(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(1.)
+ ret = while_loop_v2(lambda v: v + y < 9., lambda v: v * 3., [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testCaptureExternalTensorInBody(self):
+ x = constant_op.constant(2.)
+ y = constant_op.constant(3.)
+ ret = while_loop_v2(lambda v: v < 8., lambda v: v * y, [x])
+ grad = gradients_impl.gradients(ret, [x])
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret), 18.)
+ self.assertSequenceEqual(sess.run(grad), [9.])
+
+ def testLoopWithTensorListPushBack(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ tl = list_ops.tensor_list_push_back(tl, x)
+ tl = list_ops.tensor_list_push_back(tl, constant_op.constant(100.))
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ def testDuplicateAccumulator(self):
+ x = constant_op.constant(2.)
+
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=ScalarShape())
+
+ def Cond(x, tl):
+ del tl # Unused for Cond.
+ return x < 5.
+
+ def Body(x, tl):
+ # There is an accumulator in the loop already so we should not add
+ # another.
+ tl = list_ops.tensor_list_push_back(tl, x)
+ return x**2., tl
+
+ ret = while_loop_v2(Cond, Body, [x, tensor_list])
+
+ for op in ops.get_default_graph().get_operations():
+ if op.type == "While":
+ while_op = op
+
+ body_graph = while_v2._get_body_graph(while_op)
+ # body_graph.inputs: [counter_arg, x_arg, tl_arg, *accumulators]
+ x_input_t = body_graph.inputs[1]
+ accumulator_count = len(
+ [c for c in x_input_t.consumers() if c.type == "TensorListPushBack"])
+ self.assertEqual(accumulator_count, 1)
+
+ grad = gradients_impl.gradients(ret[0], x)
+ with self.cached_session() as sess:
+ self.assertEqual(sess.run(ret[0]), 16.)
+ self.assertSequenceEqual(sess.run(grad), [32.])
+
+ @parameterized.named_parameters(
+ ("UnknownShape", None),
+ ("PartiallyDefinedShape", [None, 2]),
+ ("FullyDefinedShape", [1, 2]),
+ )
+ def testTensorListOutputElementShape(self, shape):
+
+ def MatchShape(actual_tensor_shape):
+ # Compare the shapes, treating None dimensions as equal. We do not
+ # directly check actual_tensor_shape and tf.TensorShape(shape) for
+ # equality because tf.Dimension.__eq__ returns None if either dimension is
+ # None.
+ if shape is None:
+ self.assertIsNone(actual_tensor_shape.dims)
+ else:
+ self.assertListEqual(actual_tensor_shape.as_list(), shape)
+
+ def GetAccumulatorForInputAtIndex(while_op, idx):
+ body_graph = while_v2._get_body_graph(while_op)
+ y_input_t = body_graph.inputs[idx]
+ push_back_node = [c for c in y_input_t.consumers()
+ if c.type == "TensorListPushBack"][0]
+ output_idx = body_graph.outputs.index(push_back_node.outputs[0])
+ return while_op.outputs[output_idx]
+
+ x = constant_op.constant(2.)
+ y = array_ops.placeholder(dtype=dtypes.float32, shape=shape)
+
+ # Forward pass.
+ ret = while_loop_v2(lambda v, u: v < 8., lambda v, u: (v * v, u), [x, y])
+ while_op = ret[0].op
+ # Get the TensorList output of While op containing the accumulated values
+ # of y.
+ # while_op.inputs: [counter_arg, x_arg, y_arg, *accumulators]
+ output = GetAccumulatorForInputAtIndex(while_op, 2)
+ _, val = list_ops.tensor_list_pop_back(output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+ # Gradient pass.
+ grad = gradients_impl.gradients(ret[1], y)
+ grad_while_op = grad[0].op
+ # Get the TensorList output of gradient While op containing the accumulated
+ # values of grad_y.
+ # grad_while_op.inputs:
+ # [counter_arg, total_iters_arg, grad_x_arg, grad_y_arg, *other_args]
+ grad_output = GetAccumulatorForInputAtIndex(grad_while_op, 4)
+ _, val = list_ops.tensor_list_pop_back(grad_output,
+ element_dtype=dtypes.float32)
+ MatchShape(val.shape)
+
+
+def ScalarShape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 60c726d54c..729885169e 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -153,13 +153,13 @@ class XentTest(test.TestCase):
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
def testShapeMismatch(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
gen_nn_ops.softmax_cross_entropy_with_logits(
[[0., 1.], [2., 3.]], [[0., 1., 0.], [1., 0., 0.]])
def testNotMatrix(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
gen_nn_ops.softmax_cross_entropy_with_logits([0., 1., 2., 3.],
[0., 1., 0., 1.])
@@ -180,7 +180,7 @@ class XentTest(test.TestCase):
np.array([[0., 0., 0., 1.], [0., .5, .5, 0.]]).astype(np.float64))
def testGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l = constant_op.constant(
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
shape=[3, 4],
@@ -207,7 +207,7 @@ class XentTest(test.TestCase):
self.assertLess(err, 5e-8)
def testGradientLabelWithV2(self):
- with self.test_session():
+ with self.cached_session():
l = constant_op.constant(
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
shape=[3, 4],
@@ -225,7 +225,7 @@ class XentTest(test.TestCase):
self.assertLess(err, 5e-8)
def testSecondGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l = constant_op.constant(
[
0.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, 0.0, 0.0, 0.0, 0.0, 0.5 / 3, 0.0,
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c8b883350d..a7f57e94e3 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -2787,4 +2787,65 @@ def quantize(input, # pylint: disable=redefined-builtin
name=name)
+@tf_export("searchsorted")
+def searchsorted(sorted_sequence,
+ values,
+ side="left",
+ out_type=dtypes.int32,
+ name=None):
+ """Searches input tensor for values on the innermost dimension.
+
+ A 2-D example:
+
+ ```
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = searchsorted(sorted_sequence, values, side="left")
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+
+ result = searchsorted(sorted_sequence, values, side="right")
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+ ```
+
+ Args:
+ sorted_sequence: N-D `Tensor` containing a sorted sequence.
+ values: N-D `Tensor` containing the search values.
+ side: 'left' or 'right'; 'left' corresponds to lower_bound and 'right' to
+ upper_bound.
+ out_type: The output type (`int32` or `int64`). Default is `tf.int32`.
+ name: Optional name for the operation.
+
+ Returns:
+ An N-D `Tensor` the size of values containing the result of applying either
+ lower_bound or upper_bound (depending on side) to each value. The result
+ is not a global index to the entire `Tensor`, but the index in the last
+ dimension.
+
+ Raises:
+ ValueError: If the last dimension of `sorted_sequence >= 2^31-1` elements.
+ If the total size of values exceeds `2^31 - 1` elements.
+ If the first `N-1` dimensions of the two tensors don't match.
+ """
+ sequence_size = shape_internal(sorted_sequence)[-1]
+ values_size = shape_internal(values)[-1]
+ sorted_sequence_2d = reshape(sorted_sequence, [-1, sequence_size])
+ values_2d = reshape(values, [-1, values_size])
+ if side == "right":
+ output = gen_array_ops.upper_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ elif side == "left":
+ output = gen_array_ops.lower_bound(sorted_sequence_2d, values_2d, out_type,
+ name)
+ else:
+ raise ValueError("side must be either 'right' or 'left'. Saw: %s." % side)
+ return reshape(output, shape_internal(values))
+
+
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index f7cbfe0312..720f9f4d41 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -24,11 +24,17 @@ from tensorflow.python.ops import resources
# Re-exporting ops used by other modules.
# pylint: disable=unused-import
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
+from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index c6a6b2a7fa..f8b1ddb140 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -119,7 +119,11 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
- return tuple(tensors[:num_cond_outputs])
+ result = tuple(tensors[:num_cond_outputs])
+ if len(result) == 1:
+ return result[0]
+ else:
+ return result
@ops.RegisterGradient("If")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e3c1aa3d5a..87f8bd85a5 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -61,7 +61,7 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
-_ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -610,9 +610,10 @@ def _EnforceShapeInvariant(merge_var, next_var):
"less-specific shape." %
(input_t.name, input_t.shape, n_shape))
else:
- if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
- raise TypeError("Type %s not supported" % type(var))
- if isinstance(var, ops.IndexedSlices):
+ if not isinstance(merge_var,
+ (ops.IndexedSlices, sparse_tensor.SparseTensor)):
+ raise TypeError("Type %s not supported" % type(merge_var))
+ if isinstance(merge_var, ops.IndexedSlices):
m_values_shape = merge_var.values.get_shape()
m_indices_shape = merge_var.indices.get_shape()
m_shape_shape = tensor_shape.TensorShape(None)
@@ -2026,7 +2027,7 @@ def cond(pred,
```
"""
- if _ENABLE_COND_V2:
+ if ENABLE_COND_V2 and not context.executing_eagerly():
return cond_v2_impl.cond_v2(pred, true_fn, false_fn, name)
# We needed to make true_fn/false_fn keyword arguments for
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 908e793902..32d455bdad 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
If `merge_repeated` is `True`, merge repeated classes in the output beams.
This means that if consecutive entries in a beam are the same,
- only the first of these is emitted. That is, when the top path
- is `A B B B B`, the return value is:
+ only the first of these is emitted. That is, when the sequence is
+ `A B B * B * B` (where '*' is the blank label), the return value is:
* `A B` if `merge_repeated = True`.
- * `A B B B B` if `merge_repeated = False`.
+ * `A B B B` if `merge_repeated = False`.
Args:
inputs: 3-D `float` `Tensor`, size
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 99d30b0bd1..2ba1ea6744 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -98,10 +98,13 @@ class Beta(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a batch of three Beta distributions.
alpha = [1, 2, 3]
beta = [1, 2, 3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
dist.sample([4, 5]) # Shape [4, 5, 3]
@@ -117,7 +120,7 @@ class Beta(distribution.Distribution):
# Create batch_shape=[2, 3] via parameter broadcast:
alpha = [[1.], [2]] # Shape [2, 1]
beta = [3., 4, 5] # Shape [3]
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
# alpha broadcast as: [[1., 1, 1,],
# [2, 2, 2]]
@@ -138,7 +141,7 @@ class Beta(distribution.Distribution):
```python
alpha = tf.constant(1.0)
beta = tf.constant(2.0)
- dist = tf.distributions.Beta(alpha, beta)
+ dist = tfd.Beta(alpha, beta)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/bijector_impl.py b/tensorflow/python/ops/distributions/bijector_impl.py
index b65e64d401..9c63385dd0 100644
--- a/tensorflow/python/ops/distributions/bijector_impl.py
+++ b/tensorflow/python/ops/distributions/bijector_impl.py
@@ -825,10 +825,21 @@ class Bijector(object):
min_event_ndims=self.inverse_min_event_ndims,
event_ndims=event_ndims)):
if not self._is_injective: # No caching for non-injective
- ildjs = self._inverse_log_det_jacobian(y, **kwargs)
- return tuple(self._reduce_jacobian_det_over_event(
- y, ildj, self.inverse_min_event_ndims, event_ndims)
- for ildj in ildjs)
+ try:
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError as original_exception:
+ try:
+ x = self._inverse(y, **kwargs)
+ fldjs = self._forward_log_det_jacobian(x, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, -fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError:
+ raise original_exception
+
mapping = self._lookup(y=y, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return mapping.ildj_map[event_ndims]
@@ -917,11 +928,21 @@ class Bijector(object):
return -1. * self._constant_ildj_map[event_ndims]
x = ops.convert_to_tensor(x, name="x")
self._maybe_assert_dtype(x)
- if not self._is_injective:
- fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
- return tuple(self._reduce_jacobian_det_over_event(
- x, fldj, self.forward_min_event_ndims, event_ndims)
- for fldj in fldjs)
+ if not self._is_injective: # No caching for non-injective
+ try:
+ fldjs = self._forward_log_det_jacobian(x, **kwargs) # No caching.
+ return tuple(self._reduce_jacobian_det_over_event(
+ x, fldj, self.forward_min_event_ndims, event_ndims)
+ for fldj in fldjs)
+ except NotImplementedError as original_exception:
+ try:
+ y = self._forward(x, **kwargs)
+ ildjs = self._inverse_log_det_jacobian(y, **kwargs)
+ return tuple(self._reduce_jacobian_det_over_event(
+ y, -ildj, self.inverse_min_event_ndims, event_ndims)
+ for ildj in ildjs)
+ except NotImplementedError:
+ raise original_exception
mapping = self._lookup(x=x, kwargs=kwargs)
if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
return -mapping.ildj_map[event_ndims]
@@ -1011,12 +1032,6 @@ class Bijector(object):
def _reduce_jacobian_det_over_event(
self, y, ildj, min_event_ndims, event_ndims):
"""Reduce jacobian over event_ndims - min_event_ndims."""
-
- if not self.is_constant_jacobian:
- return math_ops.reduce_sum(
- ildj,
- self._get_event_reduce_dims(min_event_ndims, event_ndims))
-
# In this case, we need to tile the Jacobian over the event and reduce.
y_rank = array_ops.rank(y)
y_shape = array_ops.shape(y)[
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index dd25fce2ec..fbbacf2521 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -69,7 +69,7 @@ class Categorical(distribution.Distribution):
The Categorical distribution is closely related to the `OneHotCategorical` and
`Multinomial` distributions. The Categorical distribution can be intuited as
generating samples according to `argmax{ OneHotCategorical(probs) }` itself
- being identical to `argmax{ Multinomial(probs, total_count=1) }.
+ being identical to `argmax{ Multinomial(probs, total_count=1) }`.
#### Mathematical Details
@@ -83,7 +83,7 @@ class Categorical(distribution.Distribution):
The number of classes, `K`, must not exceed:
- the largest integer representable by `self.dtype`, i.e.,
- `2**(mantissa_bits+1)` (IEE754),
+ `2**(mantissa_bits+1)` (IEEE 754),
- the maximum `Tensor` index, i.e., `2**31-1`.
In other words,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 9104a1d071..415249a958 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -104,10 +104,13 @@ class Dirichlet(distribution.Distribution):
#### Examples
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 3]
@@ -129,7 +132,7 @@ class Dirichlet(distribution.Distribution):
# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
[4, 5, 6]] # shape: [2, 3]
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
dist.sample([4, 5]) # shape: [4, 5, 2, 3]
@@ -144,7 +147,7 @@ class Dirichlet(distribution.Distribution):
```python
alpha = tf.constant([1.0, 2.0, 3.0])
- dist = tf.distributions.Dirichlet(alpha)
+ dist = tfd.Dirichlet(alpha)
samples = dist.sample(5) # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 578e7b7dd2..76d980679e 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -601,7 +601,8 @@ class Distribution(_BaseDistribution):
return type(self)(**parameters)
def _batch_shape_tensor(self):
- raise NotImplementedError("batch_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "batch_shape_tensor is not implemented: {}".format(type(self).__name__))
def batch_shape_tensor(self, name="batch_shape_tensor"):
"""Shape of a single sample from a single event index as a 1-D `Tensor`.
@@ -640,7 +641,8 @@ class Distribution(_BaseDistribution):
return tensor_shape.as_shape(self._batch_shape())
def _event_shape_tensor(self):
- raise NotImplementedError("event_shape_tensor is not implemented")
+ raise NotImplementedError(
+ "event_shape_tensor is not implemented: {}".format(type(self).__name__))
def event_shape_tensor(self, name="event_shape_tensor"):
"""Shape of a single sample from a single batch as a 1-D int32 `Tensor`.
@@ -701,7 +703,8 @@ class Distribution(_BaseDistribution):
name="is_scalar_batch")
def _sample_n(self, n, seed=None):
- raise NotImplementedError("sample_n is not implemented")
+ raise NotImplementedError("sample_n is not implemented: {}".format(
+ type(self).__name__))
def _call_sample_n(self, sample_shape, seed, name, **kwargs):
with self._name_scope(name, values=[sample_shape]):
@@ -733,15 +736,19 @@ class Distribution(_BaseDistribution):
return self._call_sample_n(sample_shape, seed, name)
def _log_prob(self, value):
- raise NotImplementedError("log_prob is not implemented")
+ raise NotImplementedError("log_prob is not implemented: {}".format(
+ type(self).__name__))
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_prob(self, value, name="log_prob"):
"""Log probability density/mass function.
@@ -757,15 +764,19 @@ class Distribution(_BaseDistribution):
return self._call_log_prob(value, name)
def _prob(self, value):
- raise NotImplementedError("prob is not implemented")
+ raise NotImplementedError("prob is not implemented: {}".format(
+ type(self).__name__))
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._prob(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_prob(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def prob(self, value, name="prob"):
"""Probability density/mass function.
@@ -781,15 +792,19 @@ class Distribution(_BaseDistribution):
return self._call_prob(value, name)
def _log_cdf(self, value):
- raise NotImplementedError("log_cdf is not implemented")
+ raise NotImplementedError("log_cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log(self._cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_cdf(self, value, name="log_cdf"):
"""Log cumulative distribution function.
@@ -815,15 +830,19 @@ class Distribution(_BaseDistribution):
return self._call_log_cdf(value, name)
def _cdf(self, value):
- raise NotImplementedError("cdf is not implemented")
+ raise NotImplementedError("cdf is not implemented: {}".format(
+ type(self).__name__))
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._cdf(value, **kwargs)
- except NotImplementedError:
- return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.exp(self._log_cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def cdf(self, value, name="cdf"):
"""Cumulative distribution function.
@@ -845,15 +864,20 @@ class Distribution(_BaseDistribution):
return self._call_cdf(value, name)
def _log_survival_function(self, value):
- raise NotImplementedError("log_survival_function is not implemented")
+ raise NotImplementedError(
+ "log_survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._log_survival_function(value, **kwargs)
- except NotImplementedError:
- return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.log1p(-self.cdf(value, **kwargs))
+ except NotImplementedError:
+ raise original_exception
def log_survival_function(self, value, name="log_survival_function"):
"""Log survival function.
@@ -880,15 +904,19 @@ class Distribution(_BaseDistribution):
return self._call_log_survival_function(value, name)
def _survival_function(self, value):
- raise NotImplementedError("survival_function is not implemented")
+ raise NotImplementedError("survival_function is not implemented: {}".format(
+ type(self).__name__))
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
value = ops.convert_to_tensor(value, name="value")
try:
return self._survival_function(value, **kwargs)
- except NotImplementedError:
- return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError as original_exception:
+ try:
+ return 1. - self.cdf(value, **kwargs)
+ except NotImplementedError:
+ raise original_exception
def survival_function(self, value, name="survival_function"):
"""Survival function.
@@ -912,7 +940,8 @@ class Distribution(_BaseDistribution):
return self._call_survival_function(value, name)
def _entropy(self):
- raise NotImplementedError("entropy is not implemented")
+ raise NotImplementedError("entropy is not implemented: {}".format(
+ type(self).__name__))
def entropy(self, name="entropy"):
"""Shannon entropy in nats."""
@@ -920,7 +949,8 @@ class Distribution(_BaseDistribution):
return self._entropy()
def _mean(self):
- raise NotImplementedError("mean is not implemented")
+ raise NotImplementedError("mean is not implemented: {}".format(
+ type(self).__name__))
def mean(self, name="mean"):
"""Mean."""
@@ -928,7 +958,8 @@ class Distribution(_BaseDistribution):
return self._mean()
def _quantile(self, value):
- raise NotImplementedError("quantile is not implemented")
+ raise NotImplementedError("quantile is not implemented: {}".format(
+ type(self).__name__))
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
@@ -955,7 +986,8 @@ class Distribution(_BaseDistribution):
return self._call_quantile(value, name)
def _variance(self):
- raise NotImplementedError("variance is not implemented")
+ raise NotImplementedError("variance is not implemented: {}".format(
+ type(self).__name__))
def variance(self, name="variance"):
"""Variance.
@@ -979,11 +1011,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._variance()
- except NotImplementedError:
- return math_ops.square(self._stddev())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.square(self._stddev())
+ except NotImplementedError:
+ raise original_exception
def _stddev(self):
- raise NotImplementedError("stddev is not implemented")
+ raise NotImplementedError("stddev is not implemented: {}".format(
+ type(self).__name__))
def stddev(self, name="stddev"):
"""Standard deviation.
@@ -1008,11 +1044,15 @@ class Distribution(_BaseDistribution):
with self._name_scope(name):
try:
return self._stddev()
- except NotImplementedError:
- return math_ops.sqrt(self._variance())
+ except NotImplementedError as original_exception:
+ try:
+ return math_ops.sqrt(self._variance())
+ except NotImplementedError:
+ raise original_exception
def _covariance(self):
- raise NotImplementedError("covariance is not implemented")
+ raise NotImplementedError("covariance is not implemented: {}".format(
+ type(self).__name__))
def covariance(self, name="covariance"):
"""Covariance.
@@ -1054,7 +1094,8 @@ class Distribution(_BaseDistribution):
return self._covariance()
def _mode(self):
- raise NotImplementedError("mode is not implemented")
+ raise NotImplementedError("mode is not implemented: {}".format(
+ type(self).__name__))
def mode(self, name="mode"):
"""Mode."""
@@ -1080,7 +1121,7 @@ class Distribution(_BaseDistribution):
where `F` denotes the support of the random variable `X ~ P`.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1111,7 +1152,7 @@ class Distribution(_BaseDistribution):
denotes (Shanon) cross entropy, and `H[.]` denotes (Shanon) entropy.
Args:
- other: `tf.distributions.Distribution` instance.
+ other: `tfp.distributions.Distribution` instance.
name: Python `str` prepended to names of ops created by this function.
Returns:
@@ -1123,7 +1164,7 @@ class Distribution(_BaseDistribution):
return self._kl_divergence(other)
def __str__(self):
- return ("tf.distributions.{type_name}("
+ return ("tfp.distributions.{type_name}("
"\"{self_name}\""
"{maybe_batch_shape}"
"{maybe_event_shape}"
@@ -1139,7 +1180,7 @@ class Distribution(_BaseDistribution):
dtype=self.dtype.name))
def __repr__(self):
- return ("<tf.distributions.{type_name} "
+ return ("<tfp.distributions.{type_name} "
"'{self_name}'"
" batch_shape={batch_shape}"
" event_shape={event_shape}"
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index b631f0247c..3293cda874 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -100,8 +100,11 @@ class Gamma(distribution.Distribution):
#### Examples
```python
- dist = tf.distributions.Gamma(concentration=3.0, rate=2.0)
- dist2 = tf.distributions.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
+ dist = tfd.Gamma(concentration=3.0, rate=2.0)
+ dist2 = tfd.Gamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
Compute the gradients of samples w.r.t. the parameters:
@@ -109,7 +112,7 @@ class Gamma(distribution.Distribution):
```python
concentration = tf.constant(3.0)
rate = tf.constant(2.0)
- dist = tf.distributions.Gamma(concentration, rate)
+ dist = tfd.Gamma(concentration, rate)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index e3c6f3e789..fdeb97bf64 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -127,8 +127,8 @@ def cross_entropy(ref, other,
where `F` denotes the support of the random variable `X ~ P`.
Args:
- ref: `tf.distributions.Distribution` instance.
- other: `tf.distributions.Distribution` instance.
+ ref: `tfd.Distribution` instance.
+ other: `tfd.Distribution` instance.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index d0a987ba7c..2feaf806c0 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -71,15 +71,18 @@ class Normal(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Normal distribution.
- dist = tf.distributions.Normal(loc=0., scale=3.)
+ dist = tfd.Normal(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22.
- dist = tf.distributions.Normal(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Normal(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -94,7 +97,7 @@ class Normal(distribution.Distribution):
```python
# Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations.
- dist = tf.distributions.Normal(loc=1., scale=[11, 22.])
+ dist = tfd.Normal(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e0cf6f86f1..e8d214bbe0 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -91,8 +91,11 @@ class StudentT(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ import tensorflow_probability as tfp
+ tfd = tfp.distributions
+
# Define a single scalar Student t distribution.
- single_dist = tf.distributions.StudentT(df=3)
+ single_dist = tfd.StudentT(df=3)
# Evaluate the pdf at 1, returning a scalar Tensor.
single_dist.prob(1.)
@@ -100,9 +103,7 @@ class StudentT(distribution.Distribution):
# Define a batch of two scalar valued Student t's.
# The first has degrees of freedom 2, mean 1, and scale 11.
# The second 3, 2 and 22.
- multi_dist = tf.distributions.StudentT(df=[2, 3],
- loc=[1, 2.],
- scale=[11, 22.])
+ multi_dist = tfd.StudentT(df=[2, 3], loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -117,7 +118,7 @@ class StudentT(distribution.Distribution):
```python
# Define a batch of two Student's t distributions.
# Both have df 2 and mean 1, but different scales.
- dist = tf.distributions.StudentT(df=2, loc=1, scale=[11, 22.])
+ dist = tfd.StudentT(df=2, loc=1, scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
@@ -130,7 +131,7 @@ class StudentT(distribution.Distribution):
df = tf.constant(2.0)
loc = tf.constant(2.0)
scale = tf.constant(11.0)
- dist = tf.distributions.StudentT(df=df, loc=loc, scale=scale)
+ dist = tfd.StudentT(df=df, loc=loc, scale=scale)
samples = dist.sample(5) # Shape [5]
loss = tf.reduce_mean(tf.square(samples)) # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
@@ -138,7 +139,6 @@ class StudentT(distribution.Distribution):
```
"""
- # pylint: enable=line-too-long
def __init__(self,
df,
diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py
index 3e480a79f5..ad848dfee6 100644
--- a/tensorflow/python/ops/distributions/util.py
+++ b/tensorflow/python/ops/distributions/util.py
@@ -155,7 +155,8 @@ def get_logits_and_probs(logits=None,
probs=None,
multidimensional=False,
validate_args=False,
- name="get_logits_and_probs"):
+ name="get_logits_and_probs",
+ dtype=None):
"""Converts logit to probabilities (or vice-versa), and returns both.
Args:
@@ -169,6 +170,7 @@ def get_logits_and_probs(logits=None,
`0 <= probs <= 1` (if not `multidimensional`) or that the last dimension
of `probs` sums to one.
name: A name for this operation (optional).
+ dtype: `tf.DType` to prefer when converting args to `Tensor`s.
Returns:
logits, probs: Tuple of `Tensor`s. If `probs` has an entry that is `0` or
@@ -183,7 +185,7 @@ def get_logits_and_probs(logits=None,
raise ValueError("Must pass probs or logits, but not both.")
if probs is None:
- logits = ops.convert_to_tensor(logits, name="logits")
+ logits = ops.convert_to_tensor(logits, name="logits", dtype=dtype)
if not logits.dtype.is_floating:
raise TypeError("logits must having floating type.")
# We can early return since we constructed probs and therefore know
@@ -194,7 +196,7 @@ def get_logits_and_probs(logits=None,
return logits, nn.softmax(logits, name="probs")
return logits, math_ops.sigmoid(logits, name="probs")
- probs = ops.convert_to_tensor(probs, name="probs")
+ probs = ops.convert_to_tensor(probs, name="probs", dtype=dtype)
if not probs.dtype.is_floating:
raise TypeError("probs must having floating type.")
@@ -524,6 +526,8 @@ def matrix_diag_transform(matrix, transform=None, name=None):
Example of heteroskedastic 2-D linear regression.
```python
+ tfd = tfp.distributions
+
# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
@@ -533,7 +537,7 @@ def matrix_diag_transform(matrix, transform=None, name=None):
mu = tf.contrib.layers.fully_connected(activations, 2)
# This is a fully trainable multivariate normal!
- dist = tf.contrib.distributions.MVNCholesky(mu, chol)
+ dist = tfd.MultivariateNormalTriL(mu, chol)
# Standard log loss. Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 6263041b8d..60d73a1693 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -550,9 +550,11 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ if not isinstance(embedding_weights[0],
+ resource_variable_ops.ResourceVariable):
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index a4e7c84ae4..119d9522bd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -263,7 +264,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
@tf_export("map_fn")
-def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
+def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
swap_memory=False, infer_shape=True, name=None):
"""map on the list of tensors unpacked from `elems` on dimension 0.
@@ -305,6 +306,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
instead.
+ When executing eagerly, map_fn does not execute in parallel even if
+ `parallel_iterations` is set to a value > 1. You can still get the
+ performance benefits of running a function in parallel by using the
+ `tf.contrib.eager.defun` decorator,
+
+ ```python
+ # Assume the function being used in map_fn is fn.
+ # To ensure map_fn calls fn in parallel, use the defun decorator.
+ @tf.contrib.eager.defun
+ def func(tensor):
+ return tf.map_fn(fn, tensor)
+ ```
+
+ Note that if you use the defun decorator, any non-TensorFlow Python code
+ that you may have written in your function won't get executed. See
+ `tf.contrib.eager.defun` for more details. The recommendation would be to
+ debug without defun but switch to defun to get performance benefits of
+ running map_fn in parallel.
+
Args:
fn: The callable to be performed. It accepts one argument, which will
have the same (possibly nested) structure as `elems`. Its output
@@ -317,7 +337,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
of Tensors differing from the structure of `elems`, then `dtype` is not
optional and must have the same structure as the output of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
- in parallel.
+ in parallel. When graph building, the default value is 10. While executing
+ eagerly, the default value is set to 1.
back_prop: (optional) True enables support for back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
infer_shape: (optional) False disables tests for consistent output shapes.
@@ -363,6 +384,20 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
" SparseTensor(input.indices, map_fn(fn, input.values), "
"input.dense_shape)")
+ in_graph_mode = not context.executing_eagerly()
+ # Set the default number of parallel_iterations depending on graph/eager mode.
+ if in_graph_mode and not parallel_iterations:
+ parallel_iterations = 10
+ elif not in_graph_mode and not parallel_iterations:
+ parallel_iterations = 1
+
+ if not in_graph_mode and parallel_iterations > 1:
+ logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
+ "effect when executing eagerly. Consider calling map_fn"
+ " with tf.contrib.eager.defun to execute fn in "
+ "parallel.", 1)
+ parallel_iterations = 1
+
input_is_sequence = nest.is_sequence(elems)
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
def input_pack(x):
@@ -381,7 +416,6 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
elems_flat = input_flatten(elems)
- in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "map", elems_flat):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
# supported in Eager
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 3268b38b86..056015d6b6 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -184,7 +184,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
between_op_list.append(op)
# Clear the boolean so we won't add the inputs again.
reached_ops.remove(op)
- for inp in _Inputs(op, xs):
+ for inp in _NonEagerInputs(op, xs):
queue.append(inp.op)
# X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
@@ -196,7 +196,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
# Initialize pending count for between ops.
pending_count = collections.defaultdict(int)
for op in between_op_list:
- for x in _Inputs(op, xs):
+ for x in _NonEagerInputs(op, xs):
if x.op in between_ops:
pending_count[x.op] += 1
@@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys,
"Gradient type %s generated for complex-valued "
"tensor %s with type %s must be real" % (dtypes.as_dtype(
grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
raise TypeError(
"Tensor %s with type %s must be numeric "
@@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor):
if _IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+ return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
def _VerifyGeneratedGradients(grads, op):
@@ -341,7 +347,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
stop_ops = set()
for op in from_ops:
is_stop_op = True
- for inp in _Inputs(op, xs):
+ for inp in _NonEagerInputs(op, xs):
if pending_count[inp.op] > 0:
is_stop_op = False
break
@@ -365,10 +371,10 @@ def _IsPartitionedCall(op):
return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
-def _SymGrad(op, out_grads, xs):
+def _SymGrad(op, out_grads):
"""Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in _Inputs(op, xs)] + out_grads
- f_types = [x.dtype for x in _Inputs(op, xs)]
+ f_in = [x for x in op.inputs] + out_grads
+ f_types = [x.dtype for x in op.inputs]
f = attr_value_pb2.NameAttrList()
if _IsPartitionedCall(op):
f.name = op.get_attr("f").name
@@ -435,7 +441,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
if curr_op in from_ops:
target_op = curr_op
break
- queue.extend(t.op for t in _Inputs(curr_op, xs))
+ queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
assert target_op
raise ValueError(
"Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -468,7 +474,8 @@ def _MaybeCaptured(t):
A tensor, potentially from a different Graph/_function.FuncGraph.
"""
# pylint: disable=protected-access
- if _IsFunction(t.op.graph) and t.op.type == "Placeholder":
+ if (not isinstance(t, ops.EagerTensor) and
+ _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
for input_t, placeholder_t in _Captures(t.op.graph).items():
if t == placeholder_t:
return _MaybeCaptured(input_t)
@@ -478,9 +485,12 @@ def _MaybeCaptured(t):
# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
# _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _NonEagerInputs(op, xs):
"""Returns the inputs of op, crossing closure boundaries where necessary.
+ Does not return any captured EagerTensors, i.e., the number of tensors
+ returned may be less than than the actual number of inputs.
+
Args:
op: Operation
xs: list of Tensors we are differentiating w.r.t.
@@ -491,12 +501,19 @@ def _Inputs(op, xs):
captured inputs.
"""
if _IsFunction(op.graph): # pylint: disable=protected-access
- # If we're differentiating w.r.t. `t`, do not attempt to traverse through it
- # to a captured value. The algorithm needs to "see" `t` in this case, even
- # if it's a function input for a captured value, whereas usually we'd like
- # to traverse through these closures as if the captured value was the direct
- # input to op.
- return [t if (t in xs) else _MaybeCaptured(t) for t in op.inputs]
+ inputs = []
+ for t in op.inputs:
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through
+ # it to a captured value. The algorithm needs to "see" `t` in this case,
+ # even if it's a function input for a captured value, whereas usually we'd
+ # like to traverse through these closures as if the captured value was the
+ # direct input to op.
+ if t not in xs:
+ t = _MaybeCaptured(t)
+ # Skip captured eager inputs.
+ if isinstance(t, ops.EagerTensor): continue
+ inputs.append(t)
+ return inputs
else:
return op.inputs
@@ -799,7 +816,7 @@ def _GradientsHelper(ys,
# For function call ops, we add a 'SymbolicGradient'
# node to the graph to compute gradients.
in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads, xs))
+ lambda: _SymGrad(op, out_grads))
in_grads = _AsList(in_grads)
_VerifyGeneratedGradients(in_grads, op)
if gate_gradients and len([x for x in in_grads
@@ -814,8 +831,9 @@ def _GradientsHelper(ys,
else:
# If no grad_fn is defined or none of out_grads is available,
# just propagate a list of None backwards.
- in_grads = [None] * len(_Inputs(op, xs))
- for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+ in_grads = [None] * len(_NonEagerInputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
+ in_grads)):
if in_grad is not None:
if (isinstance(in_grad, ops.Tensor) and
t_in.dtype != dtypes.resource):
@@ -856,7 +874,7 @@ def _HasAnyNotNoneGrads(grads, op):
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
xs):
"""Update pending count for the inputs of op and enqueue ready ops."""
- for x in _Inputs(op, xs):
+ for x in _NonEagerInputs(op, xs):
pending_count[x.op] -= 1
ready = (pending_count[x.op] == 0)
if loop_state and not ready:
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3759d8a543..4f6e5dc473 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
@@ -530,6 +531,24 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
self.assertEqual(sess.run(z_grad), 3.0)
+ def testCapturedEagerTensors(self):
+ # Test that we can handle captured eager tensors unrelated to the gradient
+ # computation (i.e. we need to ignore them).
+ # TODO(skyewm): make it an error if you try to take the gradient wrt a
+ # captured EagerTensor
+ with context.eager_mode():
+ c = constant_op.constant(2.0, name="c")
+
+ @function.defun
+ def Foo():
+ x = constant_op.constant(10.0, name="x")
+ y = math_ops.multiply(x, c, name="y")
+ z = math_ops.multiply(y, 3.0, name="z")
+ g = gradients_impl.gradients(z, x)
+ return g[0]
+
+ self.assertEqual(Foo().numpy(), 6.0)
+
class StopGradientTest(test_util.TensorFlowTestCase):
@@ -1004,5 +1023,25 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
self._assert_indexed_slices_equal(total, result)
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+ def testDefaultGradYs(self):
+ with ops.Graph().as_default():
+ tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ a = constant(1.0)
+ tl = list_ops.tensor_list_push_back(tl, a)
+
+ grad_tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+ grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+ with self.cached_session() as sess:
+ self.assertEquals(sess.run(grad), 5.)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index de260f3140..1c75aab578 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -29,7 +29,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@@ -301,21 +300,21 @@ def random_flip_left_right(image, seed=None):
def _random_flip(image, flip_index, seed, scope_name):
"""Randomly (50% chance) flip an image along axis `flip_index`.
- Args:
- image: 4-D Tensor of shape `[batch, height, width, channels]` or
- 3-D Tensor of shape `[height, width, channels]`.
- flip_index: The dimension along which to flip the image.
- Vertical: 0, Horizontal: 1
- seed: A Python integer. Used to create a random seed. See
- `tf.set_random_seed`
- for behavior.
- scope_name: Name of the scope in which the ops are added.
- Returns:
- A tensor of the same type and shape as `image`.
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: Dimension along which to flip image. Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ `tf.set_random_seed`
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
- Raises:
- ValueError: if the shape of `image` not supported.
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
@@ -330,19 +329,18 @@ def _random_flip(image, flip_index, seed, scope_name):
lambda: image,
name=scope
)
- if isinstance(result, tuple):
- result = result[0] # TODO(b/111124878) remove this logic (CondV2).
return fix_image_flip_shape(image, result)
elif shape.ndims == 4:
+ batch_size = array_ops.shape(image)[0]
uniform_random = random_ops.random_uniform(
- [array_ops.shape(image)[0]], 0, 1.0, seed=seed
+ [batch_size], 0, 1.0, seed=seed
)
- mirror_cond = math_ops.less(uniform_random, .5)
- return array_ops.where(
- mirror_cond,
- image,
- functional_ops.map_fn(lambda x: array_ops.reverse(x, [flip_index]), image, dtype=image.dtype)
+ flips = math_ops.round(
+ array_ops.reshape(uniform_random, [batch_size, 1, 1, 1])
)
+ flips = math_ops.cast(flips, image.dtype)
+ flipped_input = array_ops.reverse(image, [flip_index + 1])
+ return flips * flipped_input + (1 - flips) * image
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
@@ -1029,10 +1027,10 @@ def resize_images(images,
scale_factor_width = (math_ops.to_float(new_width_const) /
math_ops.to_float(current_width))
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
- scaled_height_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_height))
- scaled_width_const = math_ops.to_int32(scale_factor *
- math_ops.to_float(current_width))
+ scaled_height_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_height)))
+ scaled_width_const = math_ops.to_int32(
+ math_ops.round(scale_factor * math_ops.to_float(current_width)))
# NOTE: Reset the size and other constants used later.
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
@@ -1176,7 +1174,7 @@ def resize_image_with_pad(image,
@tf_export('image.per_image_standardization')
def per_image_standardization(image):
- """Linearly scales `image` to have zero mean and unit norm.
+ """Linearly scales `image` to have zero mean and unit variance.
This op computes `(x - mean) / adjusted_stddev`, where `mean` is the average
of all values in image, and
@@ -1379,7 +1377,7 @@ def adjust_gamma(image, gamma=1, gain=1):
[1] http://en.wikipedia.org/wiki/Gamma_correction
"""
- with ops.op_scope([image, gamma, gain], None, 'adjust_gamma'):
+ with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name:
# Convert pixel value to DT_FLOAT for computing adjusted image.
img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32)
# Keep image dtype for computing the scale of corresponding dtype.
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 795e6bbc3e..35fdee4fad 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -2687,6 +2687,12 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
+ def testPreserveAspectRatioSquare(self):
+ x_shape = [299, 299, 3]
+ x = np.random.uniform(size=x_shape)
+
+ self._assertResizeCheckShape(x, x_shape, [320, 320], [320, 320, 3])
+
class ResizeImageWithPadTest(test_util.TensorFlowTestCase):
@@ -3667,7 +3673,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# Note: There are multiple versions of non_max_suppression v2, v3, v4.
# gen_image_ops.non_max_suppression_v2:
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3677,7 +3683,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
self.assertAllClose(selected_indices, [3, 0, 5])
# image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3.
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
@@ -3688,7 +3694,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
# gen_image_ops.non_max_suppression_v4.
score_threshold = float('-inf')
for dtype in [np.float16, np.float32]:
- with self.test_session():
+ with self.cached_session():
boxes = constant_op.constant(boxes_np, dtype=dtype)
scores = constant_op.constant(scores_np, dtype=dtype)
max_output_size = constant_op.constant(max_output_size_np)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py
index 86130a2c07..86130a2c07 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
+++ b/tensorflow/python/ops/linalg/linear_operator_addition.py
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index c367ed25ad..021ef47383 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -160,20 +160,20 @@ class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
`block_depth = 1` means `A` is symmetric circulant. For example,
```
- A = |x y z y|
- |y x y z|
- |z y x y|
- |y z y x|
+ A = |w z y x|
+ |x w z y|
+ |y x w z|
+ |z y x w|
```
`block_depth = 2` means `A` is block symmetric circulant with symemtric
- circulant blocks. For example, with `X`, `Y`, `Z` symmetric circulant,
+ circulant blocks. For example, with `W`, `X`, `Y`, `Z` symmetric circulant,
```
- A = |X Y Z Y|
- |Y X Y Z|
- |Z Y X Y|
- |Y Z Y X|
+ A = |W Z Y X|
+ |X W Z Y|
+ |Y X W Z|
+ |Z Y X W|
```
`block_depth = 3` means `A` is block symmetric circulant with block
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 78c85db557..76d659f109 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -184,7 +184,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -199,7 +199,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -215,7 +215,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -240,7 +240,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -283,7 +283,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
for adjoint in self._adjoint_options:
for adjoint_arg in self._adjoint_arg_options:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -319,7 +319,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -335,7 +335,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
@@ -353,7 +353,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for use_placeholder in self._use_placeholder_options:
for build_info in self._operator_build_infos:
for dtype in self._dtypes_to_test:
- with self.test_session(graph=ops.Graph()) as sess:
+ with self.session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index df41933f8a..4c53f33af1 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -19,13 +19,24 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import pprint
+import random
+import sys
+
+import six
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import gen_logging_ops
+from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_logging_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +51,32 @@ from tensorflow.python.util.tf_export import tf_export
# For users with Python 3 or Python 2.7
# with `from __future__ import print_function`, we could also allow lowercase.
# See https://github.com/tensorflow/tensorflow/issues/18053
-@tf_export("Print")
+
+
+# pylint: disable=invalid-name
+@deprecated("2018-08-20", "Use tf.print instead of tf.Print. Note that "
+ "tf.print returns a no-output operator that directly "
+ "prints the output. Outside of defuns or eager mode, "
+ "this operator will not be executed unless it is "
+ "directly specified in session.run or used as a "
+ "control dependency for other operators. This is "
+ "only a concern in graph mode. Below is an example "
+ "of how to ensure tf.print executes in graph mode:\n"
+ """```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print(tensor)
+ with tf.control_dependencies([print_op]):
+ out = tf.add(tensor, tensor)
+ sess.run(out)
+ ```
+Additionally, to use tf.print in python 2.7, users must make sure to import
+the following:
+
+ `from __future__ import print_function`
+""")
+@tf_export(v1=["Print"])
def Print(input_, data, message=None, first_n=None, summarize=None,
name=None):
"""Prints a list of tensors.
@@ -66,6 +102,228 @@ def Print(input_, data, message=None, first_n=None, summarize=None,
A `Tensor`. Has the same type and contents as `input_`.
"""
return gen_logging_ops._print(input_, data, message, first_n, summarize, name)
+# pylint: enable=invalid-name
+
+
+def _generate_placeholder_string(x, default_placeholder="{}"):
+ """Generate and return a string that does not appear in `x`."""
+ placeholder = default_placeholder
+ rng = random.Random(5)
+ while placeholder in x:
+ placeholder = placeholder + str(rng.randint(0, 9))
+ return placeholder
+
+
+# Temporarily disable pylint g-doc-args error to allow giving more context
+# about what the kwargs are.
+# Because we are using arbitrary-length positional arguments, python 2
+# does not support explicitly specifying the keyword arguments in the
+# function definition.
+# pylint: disable=g-doc-args
+@tf_export("print")
+def print_v2(*inputs, **kwargs):
+ """Print the specified inputs.
+
+ Returns an operator that prints the specified inputs to a desired
+ output stream or logging level. The inputs may be dense or sparse Tensors,
+ primitive python objects, data structures that contain Tensors, and printable
+ python objects. Printed tensors will recursively show the first and last
+ `summarize` elements of each dimension.
+
+ With eager execution enabled and/or inside a `tf.contrib.eager.defun` this
+ operator will automatically execute, and users only need to call `tf.print`
+ without using the return value. When constructing graphs outside of a
+ `tf.contrib.eager.defun`, one must either include the returned op
+ in the input to `session.run`, or use the operator as a control dependency for
+ executed ops by specifying `with tf.control_dependencies([print_op])`.
+
+ @compatibility(python2)
+ In python 2.7, make sure to import the following:
+ `from __future__ import print_function`
+ @end_compatibility
+
+ Example:
+ Single-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Multi-input usage:
+ ```python
+ tf.enable_eager_execution()
+ tensor = tf.range(10)
+ tf.print("tensors:", tensor, {2: tensor * 2}, output_stream=sys.stdout)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Usage in a defun:
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def f():
+ tensor = tf.range(10)
+ tf.print(tensor, output_stream=sys.stderr)
+ return tensor
+
+ range_tensor = f()
+ ```
+ (This prints "[0 1 2 ... 7 8 9]" to sys.stderr)
+
+ Usage when constructing graphs:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ print_op = tf.print("tensors:", tensor, {2: tensor * 2},
+ output_stream=sys.stdout)
+ with tf.control_dependencies([print_op]):
+ tripled_tensor = tensor * 3
+ sess.run(tripled_tensor)
+ ```
+ (This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
+ sys.stdout)
+
+ Note: This op is only partially compatible with Jupyter notebooks and colabs.
+ Because it prints to the C++ standard out / standard error, this will go
+ in the notebook kernel's console output, not in the notebook cell output.
+
+ Args:
+ *inputs: Positional arguments that are the inputs to print. Inputs in the
+ printed output will be separated by spaces. Inputs may be python
+ primitives, tensors, data structures such as dicts and lists that
+ may contain tensors (with the data structures possibly nested in
+ arbitrary ways), and printable python objects.
+ output_stream: The output stream or logging level to print to. Defaults to
+ sys.stderr, but sys.stdout, tf.logging.info, tf.logging.warning, and
+ tf.logging.error are also supported.
+ summarize: The first and last `summarize` elements within each dimension are
+ recursively printed per Tensor. If None, then the first 3 and last 3
+ elements of each dimension are printed for each tensor. If set to -1, it
+ will print all elements of every tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A print operator that prints the specified inputs in the specified output
+ stream or logging level.
+
+ Raises:
+ ValueError: If an unsupported output stream is specified.
+ """
+ # Because we are using arbitrary-length positional arguments, python 2
+ # does not support explicitly specifying the keyword arguments in the
+ # function definition. So, we manually get the keyword arguments w/ default
+ # values here.
+ output_stream = kwargs.pop("output_stream", sys.stderr)
+ name = kwargs.pop("name", None)
+ summarize = kwargs.pop("summarize", 3)
+ if kwargs:
+ raise ValueError("Unrecognized keyword arguments for tf.print: %s" % kwargs)
+ format_name = None
+ if name:
+ format_name = name + "_format"
+
+ # Match the C++ string constants representing the different output streams.
+ # Keep this updated!
+ output_stream_to_constant = {
+ sys.stdout: "stdout",
+ sys.stderr: "stderr",
+ tf_logging.INFO: "log(info)",
+ tf_logging.info: "log(info)",
+ tf_logging.WARN: "log(warning)",
+ tf_logging.warning: "log(warning)",
+ tf_logging.warn: "log(warning)",
+ tf_logging.ERROR: "log(error)",
+ tf_logging.error: "log(error)",
+ }
+
+ output_stream_string = output_stream_to_constant.get(output_stream)
+ if not output_stream_string:
+ raise ValueError(
+ "Unsupported output stream or logging level " +
+ str(output_stream) + ". Supported streams are sys.stdout, "
+ "sys.stderr, tf.logging.info, "
+ "tf.logging.warning, tf.logging.error")
+
+ # If we are only printing a single string scalar, there is no need to format
+ if (len(inputs) == 1 and tensor_util.is_tensor(inputs[0])
+ and (not isinstance(inputs[0], sparse_tensor.SparseTensor))
+ and inputs[0].shape and (inputs[0].dtype == dtypes.string)):
+ formatted_string = inputs[0]
+ # Otherwise, we construct an appropriate template for the tensors we are
+ # printing, and format the template using those tensors.
+ else:
+ # For each input to this print function, we extract any nested tensors,
+ # and construct an appropriate template to format representing the
+ # printed input.
+ templates = []
+ tensors = []
+ tensor_free_structure = nest.map_structure(
+ lambda x: "" if tensor_util.is_tensor(x) else x,
+ inputs)
+ tensor_free_template = " ".join(pprint.pformat(x)
+ for x in tensor_free_structure)
+ placeholder = _generate_placeholder_string(tensor_free_template)
+
+ for input_ in inputs:
+ placeholders = []
+ # Use the nest utilities to flatten & process any nested elements in this
+ # input. The placeholder for a tensor in the template should be the
+ # placeholder string, and the placeholder for a non-tensor can just be
+ # the printed value of the non-tensor itself.
+ for x in nest.flatten(input_):
+ # support sparse tensors
+ if isinstance(x, sparse_tensor.SparseTensor):
+ tensors.extend([x.indices, x.values, x.dense_shape])
+ placeholders.append(
+ "SparseTensor(indices={}, values={}, shape={})".format(
+ placeholder, placeholder, placeholder)
+ )
+ elif tensor_util.is_tensor(x):
+ tensors.append(x)
+ placeholders.append(placeholder)
+ else:
+ placeholders.append(x)
+
+ if isinstance(input_, six.string_types):
+ # If the current input to format/print is a normal string, that string
+ # can act as the template.
+ cur_template = input_
+ else:
+ # We pack the placeholders into a data structure that matches the
+ # input data structure format, then format that data structure
+ # into a string template.
+ #
+ # NOTE: We must use pprint.pformat here for building the template for
+ # unordered data structures such as `dict`, because `str` doesn't
+ # guarantee orderings, while pprint prints in sorted order. pprint
+ # will match the ordering of `nest.flatten`.
+ # This even works when nest.flatten reorders OrderedDicts, because
+ # pprint is printing *after* the OrderedDicts have been reordered.
+ cur_template = pprint.pformat(
+ nest.pack_sequence_as(input_, placeholders))
+ templates.append(cur_template)
+
+ # We join the templates for the various inputs into a single larger
+ # template. We also remove all quotes surrounding the placeholders, so that
+ # the formatted/printed output will not contain quotes around tensors.
+ # (example of where these quotes might appear: if we have added a
+ # placeholder string into a list, then pretty-formatted that list)
+ template = " ".join(templates)
+ template = template.replace("'" + placeholder + "'", placeholder)
+ formatted_string = string_ops.string_format(
+ inputs=tensors, template=template, placeholder=placeholder,
+ summarize=summarize,
+ name=format_name)
+
+ return gen_logging_ops.print_v2(formatted_string,
+ output_stream=output_stream_string,
+ name=name)
+# pylint: enable=g-doc-args
@ops.RegisterGradient("Print")
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 561a341cf3..5443699ddd 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -422,7 +422,7 @@ class TextFileInitializer(TableInitializerBase):
* `palmer -> 30`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
"test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
...
table.init.run()
@@ -435,9 +435,9 @@ class TextFileInitializer(TableInitializerBase):
* `palmer 30 -> 2`
```python
- table = tf.contrib.lookup.HashTable(tf.contrib.lookup.TextFileInitializer(
- "test.txt", tf.string, tf.contrib.lookup.TextFileIndex.WHOLE_LINE,
- tf.int64, tf.contrib.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
+ table = tf.lookup.HashTable(tf.lookup.TextFileInitializer(
+ "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
+ tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
...
table.init.run()
```
@@ -953,7 +953,7 @@ def index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
- table = tf.contrib.lookup.index_table_from_file(
+ table = tf.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -1054,21 +1054,21 @@ def index_table_from_tensor(vocabulary_list,
Any lookup of an out-of-vocabulary token will return a bucket ID based on its
hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
- `default_value`.
- The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`.
+ `default_value`. The bucket ID range is
+ `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
- table = tf.contrib.lookup.index_table_from_tensor(
- mapping=vocabulary_list, num_oov_buckets=1, default_value=-1)
+ table = tf.lookup.index_table_from_tensor(
+ vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
...
@@ -1093,7 +1093,7 @@ def index_table_from_tensor(vocabulary_list,
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
- ValueError: If `mapping` is invalid.
+ ValueError: If `vocabulary_list` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if vocabulary_list is None:
@@ -1185,7 +1185,7 @@ def index_to_string_table_from_file(vocabulary_file,
```python
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_file(
+ table = tf.lookup.index_to_string_table_from_file(
vocabulary_file="test.txt", default_value="UNKNOWN")
values = table.lookup(indices)
...
@@ -1250,25 +1250,25 @@ def index_to_string_table_from_tensor(vocabulary_list,
"""Returns a lookup table that maps a `Tensor` of indices into strings.
This operation constructs a lookup table to map int64 indices into string
- values. The mapping is initialized from a string `mapping` 1-D `Tensor` where
- each element is a value and the corresponding index within the tensor is the
- key.
+ values. The mapping is initialized from a string `vocabulary_list` 1-D
+ `Tensor` where each element is a value and the corresponding index within the
+ tensor is the key.
- Any input which does not have a corresponding index in 'mapping'
+ Any input which does not have a corresponding index in 'vocabulary_list'
(an out-of-vocabulary entry) is assigned the `default_value`
The underlying table must be initialized by calling
`tf.tables_initializer.run()` or `table.init.run()` once.
- Elements in `mapping` cannot have duplicates, otherwise when executing the
- table initializer op, it will throw a `FailedPreconditionError`.
+ Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
+ the table initializer op, it will throw a `FailedPreconditionError`.
Sample Usages:
```python
vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
indices = tf.constant([1, 5], tf.int64)
- table = tf.contrib.lookup.index_to_string_table_from_tensor(
+ table = tf.lookup.index_to_string_table_from_tensor(
vocabulary_list, default_value="UNKNOWN")
values = table.lookup(indices)
...
diff --git a/tensorflow/python/ops/losses/util_test.py b/tensorflow/python/ops/losses/util_test.py
index 7fa7a41fca..df2e60e2e4 100644
--- a/tensorflow/python/ops/losses/util_test.py
+++ b/tensorflow/python/ops/losses/util_test.py
@@ -28,7 +28,7 @@ class LossesUtilTest(test.TestCase):
def testGetRegularizationLoss(self):
# Empty regularization collection should evaluate to 0.0.
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0.0, util.get_regularization_loss().eval())
# Loss should sum.
@@ -36,14 +36,14 @@ class LossesUtilTest(test.TestCase):
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, util.get_regularization_loss().eval())
# Check scope capture mechanism.
with ops.name_scope('scope1'):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(-1.0))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(-1.0, util.get_regularization_loss('scope1').eval())
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 33e7a5533b..f57abf6704 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1088,9 +1088,6 @@ def floordiv(x, y, name=None):
`x // y` floor division in Python 3 and in Python 2.7 with
`from __future__ import division`.
- Note that for efficiency, `floordiv` uses C semantics for negative numbers
- (unlike Python and Numpy).
-
`x` and `y` must have the same type, and the result will have the same type
as well.
@@ -1100,7 +1097,7 @@ def floordiv(x, y, name=None):
name: A name for the operation (optional).
Returns:
- `x / y` rounded down (except possibly towards zero for negative integers).
+ `x / y` rounded down.
Raises:
TypeError: If the inputs are complex.
@@ -2901,21 +2898,23 @@ def tensordot(a, b, axes, name=None):
shape_a = a.get_shape().as_list()
axes = [i if i >= 0 else i + len(shape_a) for i in axes]
free = [i for i in xrange(len(shape_a)) if i not in axes]
- free_dims_static = [shape_a[i] for i in free]
+ axes_dims = [shape_a[i] for i in axes]
+ free_dims = [shape_a[i] for i in free]
+ free_dims_static = free_dims
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ free = ops.convert_to_tensor(free, dtype=dtypes.int32, name="free")
+ shape_a = array_ops.shape(a)
else:
free_dims_static = None
- shape_a = array_ops.shape(a)
- rank_a = array_ops.rank(a)
- axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
- axes = cast(axes >= 0, dtypes.int32) * axes + cast(
- axes < 0, dtypes.int32) * (
- axes + rank_a)
- free, _ = array_ops.setdiff1d(range(rank_a), axes)
+ shape_a = array_ops.shape(a)
+ rank_a = array_ops.rank(a)
+ axes = ops.convert_to_tensor(axes, dtype=dtypes.int32, name="axes")
+ axes = array_ops.where(axes >= 0, axes, axes + rank_a)
+ free, _ = array_ops.setdiff1d(range(rank_a), axes)
free_dims = array_ops.gather(shape_a, free)
axes_dims = array_ops.gather(shape_a, axes)
prod_free_dims = reduce_prod(free_dims)
prod_axes_dims = reduce_prod(axes_dims)
- perm = array_ops.concat([axes_dims, free_dims], 0)
if flipped:
perm = array_ops.concat([axes, free], 0)
new_shape = array_ops.stack([prod_axes_dims, prod_free_dims])
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 2861f40586..3f64f0af9a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,7 +22,6 @@ import numbers
import numpy as np
-from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1672,47 +1671,24 @@ def _softmax(logits, compute_op, dim=-1, name=None):
shape = logits.get_shape()
is_last_dim = (dim is -1) or (dim == shape.ndims - 1)
- # TODO(phawkins): remove after 2018/8/27 and simplify this code.
- softmax_accepts_r1_or_greater = compat.forward_compatible(2018, 8, 27)
- reshape_required = (not softmax_accepts_r1_or_greater) and shape.ndims != 2
if is_last_dim:
- if reshape_required:
- # If dim is the last dimension, simply reshape the logits to a matrix and
- # apply the internal softmax.
- input_shape = array_ops.shape(logits)
- logits = _flatten_outer_dims(logits)
- output = compute_op(logits)
- output = array_ops.reshape(output, input_shape, name=name)
- return output
return compute_op(logits, name=name)
- # If dim is not the last dimension, we have to do a reshape and transpose so
- # that we can still perform softmax on its last dimension.
+ # If dim is not the last dimension, we have to do a transpose so that we can
+ # still perform softmax on its last dimension.
# Swap logits' dimension of dim and its last dimension.
input_rank = array_ops.rank(logits)
dim_axis = dim % shape.ndims
logits = _swap_axis(logits, dim_axis, math_ops.subtract(input_rank, 1))
- shape_after_swap = array_ops.shape(logits)
- if reshape_required:
- # Reshape logits into a matrix.
- logits = _flatten_outer_dims(logits)
-
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
-
- # Transform back the output tensor.
- output = array_ops.reshape(output, shape_after_swap)
- else:
- # Do the actual softmax on its last dimension.
- output = compute_op(logits)
+ # Do the actual softmax on its last dimension.
+ output = compute_op(logits)
output = _swap_axis(
output, dim_axis, math_ops.subtract(input_rank, 1), name=name)
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
+ # Make shape inference work since transpose may erase its static shape.
output.set_shape(shape)
return output
diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 015181af47..07fc9433a2 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -123,6 +123,8 @@ cuda_py_test(
"//third_party/py/numpy",
"//tensorflow/python:layers",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/ops/losses",
],
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index d403b0c61a..6e276dee55 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
+from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients as gradient_ops
@@ -300,28 +302,129 @@ class ArrayTest(PForTest):
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+class BitwiseTest(PForTest):
+
+ def test_unary_cwise(self):
+ for op in [bitwise_ops.invert]:
+ x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
+
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ return op(x1)
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.int32])
+
+ def test_binary_cwise(self):
+ binary_ops = [
+ bitwise_ops.bitwise_and,
+ bitwise_ops.bitwise_or,
+ bitwise_ops.bitwise_xor,
+ bitwise_ops.left_shift,
+ bitwise_ops.right_shift,
+ ]
+ for op in binary_ops:
+ x = random_ops.random_uniform([7, 3, 5], maxval=10, dtype=dtypes.int32)
+ y = random_ops.random_uniform([3, 5], maxval=10, dtype=dtypes.int32)
+
+ output_dtypes = []
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y1 = array_ops.gather(y, i)
+ outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
+ # pylint: enable=cell-var-from-loop
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
+
class MathTest(PForTest):
def test_unary_cwise_ops(self):
- for op in [
- math_ops.tanh, nn.relu, math_ops.sigmoid, math_ops.negative,
- math_ops.square
- ]:
+ complex_ops = [
+ math_ops.angle,
+ math_ops.imag,
+ math_ops.complex_abs,
+ math_ops.real,
+ math_ops.conj,
+ ]
+ real_ops = [
+ lambda x: math_ops.acosh(1 + math_ops.square(x)),
+ math_ops.abs,
+ math_ops.acos,
+ math_ops.asin,
+ math_ops.asinh,
+ math_ops.atan,
+ math_ops.atanh,
+ math_ops.bessel_i0e,
+ math_ops.bessel_i1e,
+ math_ops.cos,
+ math_ops.cosh,
+ math_ops.digamma,
+ math_ops.erf,
+ math_ops.erfc,
+ math_ops.exp,
+ math_ops.expm1,
+ math_ops.inv,
+ math_ops.is_finite,
+ math_ops.is_inf,
+ math_ops.lgamma,
+ math_ops.log,
+ math_ops.log1p,
+ math_ops.neg,
+ math_ops.negative,
+ math_ops.reciprocal,
+ math_ops.rint,
+ math_ops.round,
+ math_ops.rsqrt,
+ math_ops.sigmoid,
+ math_ops.sign,
+ math_ops.sin,
+ math_ops.sinh,
+ math_ops.sqrt,
+ math_ops.square,
+ math_ops.tan,
+ math_ops.tanh,
+ math_ops.tanh,
+ nn.elu,
+ nn.relu,
+ nn.relu6,
+ nn.selu,
+ nn.softplus,
+ nn.softsign,
+ ]
+ for op in complex_ops + real_ops:
x = random_ops.random_uniform([3, 5])
+ if op in complex_ops:
+ y = random_ops.random_uniform([3, 5])
+ x = math_ops.complex(x, y)
# pylint: disable=cell-var-from-loop
+ output_dtypes = []
def loop_fn(i):
x1 = array_ops.gather(x, i)
- y = op(x1)
- loss = math_ops.reduce_sum(y * y)
- return op(x), y, gradient_ops.gradients(loss, x1)
+ y1 = op(x1)
+ outputs = [op(x), y1]
+ if y1.dtype == dtypes.float32:
+ loss = math_ops.reduce_sum(y1 * y1)
+ grad = gradient_ops.gradients(loss, x1)
+ if grad and grad[0] is not None:
+ outputs.extend(grad)
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
# pylint: enable=cell-var-from-loop
- self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
def test_unary_cwise_no_grad(self):
- for op in [math_ops.ceil, math_ops.floor, math_ops.logical_not]:
+ for op in [math_ops.ceil,
+ math_ops.floor,
+ math_ops.logical_not]:
x = random_ops.random_uniform([3, 5])
if op == math_ops.logical_not:
x = x > 0
@@ -336,33 +439,80 @@ class MathTest(PForTest):
def test_binary_cwise_ops(self):
logical_ops = [
- math_ops.logical_and, math_ops.logical_or, math_ops.logical_xor
- ]
- bool_ops = [
- math_ops.less, math_ops.less_equal, math_ops.greater,
- math_ops.greater_equal, math_ops.equal, math_ops.not_equal
+ math_ops.logical_and,
+ math_ops.logical_or,
+ math_ops.logical_xor
]
+
+ # Wrapper functions restricting the range of inputs of zeta and polygamma.
+ def safe_polygamma(x, y):
+ return math_ops.polygamma(
+ math_ops.round(clip_ops.clip_by_value(y, 1, 10)),
+ x * x + 1)
+
+ def safe_zeta(x, y):
+ return math_ops.zeta(x * x + 1, y * y)
+
float_ops = [
- math_ops.add, math_ops.subtract, math_ops.multiply, math_ops.divide,
- math_ops.maximum, math_ops.minimum
+ math_ops.add,
+ math_ops.add_v2,
+ math_ops.atan2,
+ math_ops.complex,
+ math_ops.div,
+ math_ops.divide,
+ math_ops.div_no_nan,
+ math_ops.equal,
+ math_ops.floor_div,
+ math_ops.floor_mod,
+ math_ops.greater,
+ math_ops.greater_equal,
+ math_ops.igamma,
+ math_ops.igammac,
+ math_ops.igamma_grad_a,
+ math_ops.less,
+ math_ops.less_equal,
+ math_ops.maximum,
+ math_ops.minimum,
+ math_ops.mod,
+ math_ops.multiply,
+ math_ops.not_equal,
+ math_ops.pow,
+ math_ops.squared_difference,
+ math_ops.subtract,
+ math_ops.truncate_mod,
+ safe_polygamma,
+ safe_zeta,
]
- for op in logical_ops + bool_ops + float_ops:
+ for op in logical_ops + float_ops:
x = random_ops.random_uniform([7, 3, 5])
y = random_ops.random_uniform([3, 5])
if op in logical_ops:
x = x > 0
y = y > 0
+ output_dtypes = []
# pylint: disable=cell-var-from-loop
def loop_fn(i):
x1 = array_ops.gather(x, i)
y1 = array_ops.gather(y, i)
- return op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)
-
+ outputs = [op(x, y), op(x1, y), op(x, y1), op(x1, y1), op(x1, x1)]
+ del output_dtypes[:]
+ output_dtypes.extend([t.dtype for t in outputs])
+ return outputs
# pylint: enable=cell-var-from-loop
- dtype = dtypes.float32 if op in float_ops else dtypes.bool
- self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtype] * 5)
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+
+ def test_approximate_equal(self):
+ x = random_ops.random_uniform([3, 5])
+ y = random_ops.random_uniform([3, 5])
+
+ def loop_fn(i):
+ x1 = array_ops.gather(x, i)
+ y1 = array_ops.gather(y, i)
+ return math_ops.approximate_equal(x1, y1)
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.bool])
def test_addn(self):
x = random_ops.random_uniform([2, 3, 5])
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a97f..1f026b3660 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@ def jacobian(output, inputs, use_pfor=True):
[y_1, ..., y_n, x_1, ..., x_m].
"""
flat_inputs = nest.flatten(inputs)
+ output_tensor_shape = output.shape
output_shape = array_ops.shape(output)
output = array_ops.reshape(output, [-1])
@@ -65,6 +66,7 @@ def jacobian(output, inputs, use_pfor=True):
new_shape = array_ops.concat(
[output_shape, array_ops.shape(out)[1:]], axis=0)
out = array_ops.reshape(out, new_shape)
+ out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index 628c6764cd..5467f55af6 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -32,6 +32,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import layers as tf_layers
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops as tf_control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -355,6 +357,30 @@ class GradientsTest(test.TestCase):
self.run_and_assert_equal(answer, jacobian_pfor)
self.run_and_assert_equal(answer, jacobian_while)
+ def test_jacobian_scan_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ elems = random_ops.random_uniform([6])
+ # Shape y: [6, 3, 4]
+ y = functional_ops.scan(lambda a, e: a + e, elems, initializer=x)
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [6, 3, 4, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
+ def test_jacobian_while_loop_shape(self):
+ # Shape x: [3, 4]
+ x = random_ops.random_uniform([3, 4])
+ _, y = tf_control_flow_ops.while_loop(lambda i, a: i > 5.,
+ lambda i, a: (i + 1, a + i),
+ (constant_op.constant(0.), x))
+ # Shape y: [2, 3]
+ y = y[:2, :3]
+ jacobian = gradients.jacobian(y, x)
+
+ expected_shape = [2, 3, 3, 4]
+ self.assertAllEqual(expected_shape, jacobian.shape.as_list())
+
def test_jacobian_unknown_shape(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=[None, None])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index f9153b6d7d..e0f6d51881 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -1922,37 +1923,114 @@ def _convert_cast(pfor_input):
return wrap(math_ops.cast(inp, dtype), True)
-# Note that ops handled here do not have attributes except "T", and hence don't
-# need extra arguments passed to the cwise_op call below.
+@RegisterPForWithArgs("Abs", math_ops.abs)
+@RegisterPForWithArgs("Acosh", math_ops.acosh)
+@RegisterPForWithArgs("Acos", math_ops.acos)
@RegisterPForWithArgs("Add", math_ops.add)
+@RegisterPForWithArgs("AddV2", math_ops.add_v2)
+@RegisterPForWithArgs("Angle", math_ops.angle)
+@RegisterPForWithArgs("Asinh", math_ops.asinh)
+@RegisterPForWithArgs("Asin", math_ops.asin)
+@RegisterPForWithArgs("Atan2", math_ops.atan2)
+@RegisterPForWithArgs("Atanh", math_ops.atanh)
+@RegisterPForWithArgs("Atan", math_ops.atan)
+@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e)
+@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e)
+@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and)
+@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or)
+@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor)
@RegisterPForWithArgs("Ceil", math_ops.ceil)
+@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs)
+@RegisterPForWithArgs("Complex", math_ops.complex)
+@RegisterPForWithArgs("Conj", math_ops.conj)
+@RegisterPForWithArgs("Cosh", math_ops.cosh)
+@RegisterPForWithArgs("Cos", math_ops.cos)
+@RegisterPForWithArgs("Digamma", math_ops.digamma)
+@RegisterPForWithArgs("Div", math_ops.div)
+@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan)
+@RegisterPForWithArgs("Elu", nn_ops.elu)
@RegisterPForWithArgs("Equal", math_ops.equal)
-@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Erfc", math_ops.erfc)
+@RegisterPForWithArgs("Erf", math_ops.erf)
+@RegisterPForWithArgs("Expm1", math_ops.expm1)
+@RegisterPForWithArgs("Exp", math_ops.exp)
+@RegisterPForWithArgs("FloorDiv", math_ops.floor_div)
@RegisterPForWithArgs("Floor", math_ops.floor)
-@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("FloorMod", math_ops.floor_mod)
@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal)
-@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Greater", math_ops.greater)
+@RegisterPForWithArgs("Igammac", math_ops.igammac)
+@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a)
+@RegisterPForWithArgs("Igamma", math_ops.igamma)
+@RegisterPForWithArgs("Imag", math_ops.imag)
+@RegisterPForWithArgs("Invert", bitwise_ops.invert)
+@RegisterPForWithArgs("Inv", math_ops.inv)
+@RegisterPForWithArgs("IsFinite", math_ops.is_finite)
+@RegisterPForWithArgs("IsInf", math_ops.is_inf)
+@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift)
@RegisterPForWithArgs("LessEqual", math_ops.less_equal)
-@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
+@RegisterPForWithArgs("Less", math_ops.less)
+@RegisterPForWithArgs("Lgamma", math_ops.lgamma)
+@RegisterPForWithArgs("Log1p", math_ops.log1p)
@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and)
@RegisterPForWithArgs("LogicalNot", math_ops.logical_not)
+@RegisterPForWithArgs("LogicalOr", math_ops.logical_or)
@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor)
+@RegisterPForWithArgs("Log", math_ops.log)
@RegisterPForWithArgs("Maximum", math_ops.maximum)
@RegisterPForWithArgs("Minimum", math_ops.minimum)
+@RegisterPForWithArgs("Mod", math_ops.mod)
@RegisterPForWithArgs("Mul", math_ops.multiply)
@RegisterPForWithArgs("Neg", math_ops.negative)
+@RegisterPForWithArgs("NotEqual", math_ops.not_equal)
+@RegisterPForWithArgs("Polygamma", math_ops.polygamma)
+@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("RealDiv", math_ops.divide)
+@RegisterPForWithArgs("Real", math_ops.real)
+@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
+@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
+@RegisterPForWithArgs("Relu6", nn_ops.relu6)
@RegisterPForWithArgs("Relu", nn_ops.relu)
+@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
+@RegisterPForWithArgs("Rint", math_ops.rint)
+@RegisterPForWithArgs("Round", math_ops.round)
+@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
+@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
+@RegisterPForWithArgs("Selu", nn_ops.selu)
@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
+@RegisterPForWithArgs("Sign", math_ops.sign)
+@RegisterPForWithArgs("Sinh", math_ops.sinh)
+@RegisterPForWithArgs("Sin", math_ops.sin)
+@RegisterPForWithArgs("Softplus", nn_ops.softplus)
+@RegisterPForWithArgs("Softsign", nn_ops.softsign)
+@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
+@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
+@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
@RegisterPForWithArgs("Square", math_ops.square)
@RegisterPForWithArgs("Sub", math_ops.subtract)
@RegisterPForWithArgs("Tanh", math_ops.tanh)
+@RegisterPForWithArgs("Tan", math_ops.tan)
+@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div)
+@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod)
+@RegisterPForWithArgs("Zeta", math_ops.zeta)
def _convert_cwise(pfor_input, op_type, op_func):
- del op_type
+ # Note that ops handled here do not have attributes except "T" and "Tout", and
+ # hence don't need extra arguments passed to the cwise_op call below.
+ for attr in pfor_input.op.node_def.attr.keys():
+ assert attr in [u"T", u"Tout"], (op_type, attr)
pfor_input.expanddim_inputs_for_broadcast()
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
+@RegisterPFor("ApproximateEqual")
+def _convert_approximate_equal(pfor_input):
+ pfor_input.expanddim_inputs_for_broadcast()
+ x = pfor_input.input(0)[0]
+ y = pfor_input.input(1)[0]
+ tolerance = pfor_input.get_attr("tolerance")
+ return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True)
+
+
@RegisterPFor("Shape")
def _convert_shape(pfor_input):
out_type = pfor_input.get_attr("out_type")
@@ -2009,10 +2087,14 @@ def _convert_biasaddgrad(pfor_input):
# Some required ops are not exposed under the tf namespace. Hence relying on
# _create_op to create them.
+@RegisterPForWithArgs("EluGrad")
+@RegisterPForWithArgs("Relu6Grad")
@RegisterPForWithArgs("ReluGrad")
-@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SeluGrad")
@RegisterPForWithArgs("SigmoidGrad")
@RegisterPForWithArgs("SoftplusGrad")
+@RegisterPForWithArgs("SoftsignGrad")
+@RegisterPForWithArgs("TanhGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 8224097ac4..b3e03a0135 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -981,9 +981,10 @@ def parse_sequence_example(serialized,
name: A name for this operation (optional).
Returns:
- A tuple of two `dict`s, each mapping keys to `Tensor`s and `SparseTensor`s.
- The first dict contains the context key/values.
- The second dict contains the feature_list key/values.
+ A tuple of three `dict`s, each mapping keys to `Tensor`s and
+ `SparseTensor`s. The first dict contains the context key/values,
+ the second dict contains the feature_list key/values, and the final dict
+ contains the lengths of any dense feature_list features.
Raises:
ValueError: if any feature is invalid.
@@ -1584,7 +1585,8 @@ def decode_csv(records,
record_defaults: A list of `Tensor` objects with specific types.
Acceptable types are `float32`, `float64`, `int32`, `int64`, `string`.
One tensor per column of the input record, with either a
- scalar default value for that column or empty if the column is required.
+ scalar default value for that column or an empty vector if the column is
+ required.
field_delim: An optional `string`. Defaults to `","`.
char delimiter to separate fields in a record.
use_quote_delim: An optional `bool`. Defaults to `True`.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 55c2eb5fa4..4a126e9d7a 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -48,14 +48,14 @@ def get_resource_handle_data(graph_op):
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ handle_data = pywrap_tensorflow.GetHandleShapeAndType(
graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
compat.as_bytes(handle_data))
-def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
+def eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
"""Creates a variable handle with information to do shape inference."""
container = ops.get_default_graph()._container # pylint: disable=protected-access
if container is None:
@@ -397,61 +397,33 @@ class ResourceVariable(variables.RefVariable):
# When in eager mode use a uid for the shared_name, to prevent
# accidental sharing.
shared_name = "%s_%d" % (handle_name, ops.uid())
- if init_from_fn:
- # Use attr_scope and device(None) to simulate the behavior of
- # colocate_with when the variable we want to colocate with doesn't
- # yet exist.
- if self._in_graph_mode:
- attr = attr_value_pb2.AttrValue(
- list=attr_value_pb2.AttrValue.ListValue(
- s=[compat.as_bytes("loc:@%s" % handle_name)]))
- with ops.get_default_graph()._attr_scope({"_class": attr}):
- with ops.name_scope("Initializer"), ops.device(None):
- initial_value = ops.convert_to_tensor(
- initial_value(), name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
- else:
- initial_value = initial_value()
- with ops.name_scope("Initializer"):
- initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- self._handle = _eager_safe_variable_handle(
- shape=initial_value.get_shape(),
- dtype=initial_value.dtype.base_dtype,
- shared_name=shared_name,
- name=name,
- graph_mode=False)
- self._shape = initial_value.get_shape()
- # pylint: enable=protected-access
-
- # Or get the initial value from a Tensor or Python object.
- else:
- with ops.name_scope("Initializer"):
+ # Use attr_scope and device(None) to simulate the behavior of
+ # colocate_with when the variable we want to colocate with doesn't
+ # yet exist.
+ attr = attr_value_pb2.AttrValue(
+ list=attr_value_pb2.AttrValue.ListValue(
+ s=[compat.as_bytes("loc:@%s" % handle_name)]))
+ with ops.get_default_graph()._attr_scope({"_class": attr}):
+ with ops.name_scope("Initializer"), ops.device(None):
initial_value = ops.convert_to_tensor(
- initial_value, name="initial_value", dtype=dtype)
- # pylint: disable=protected-access
- if (self._in_graph_mode and initial_value is not None and
- initial_value.op._get_control_flow_context() is not None):
- raise ValueError(
- "Initializer for variable %s is from inside a control-flow "
- "construct, such as a loop or conditional. When creating a "
- "variable inside a loop or conditional, use a lambda as the "
- "initializer." % name)
- # pylint: enable=protected-access
- self._handle = _eager_safe_variable_handle(
+ initial_value() if init_from_fn else initial_value,
+ name="initial_value", dtype=dtype)
+ self._handle = eager_safe_variable_handle(
shape=initial_value.get_shape(),
dtype=initial_value.dtype.base_dtype,
shared_name=shared_name,
name=name,
graph_mode=self._in_graph_mode)
- self._shape = initial_value.get_shape()
-
+ self._shape = initial_value.shape
+ # pylint: disable=protected-access
+ if (self._in_graph_mode and initial_value is not None and
+ initial_value.op._get_control_flow_context() is not None):
+ raise ValueError(
+ "Initializer for variable %s is from inside a control-flow "
+ "construct, such as a loop or conditional. When creating a "
+ "variable inside a loop or conditional, use a lambda as the "
+ "initializer." % name)
+ # pylint: enable=protected-access
self._unique_id = shared_name
self._initial_value = initial_value if self._in_graph_mode else None
self._handle_name = handle_name + ":0"
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 5c00d929bf..5a3a5cc225 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -709,6 +709,10 @@ def _dynamic_rnn_loop(cell,
Raises:
ValueError: If the input depth cannot be inferred via shape inference
from the inputs.
+ ValueError: If time_step is not the same for all the elements in the
+ inputs.
+ ValueError: If batch_size is not the same for all the elements in the
+ inputs.
"""
state = initial_state
assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c11c9ccaae..43cca1a498 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -428,7 +428,7 @@ class BasicRNNCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._kernel = self.add_variable(
@@ -525,7 +525,7 @@ class GRUCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
self._gate_kernel = self.add_variable(
@@ -705,7 +705,7 @@ class BasicLSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units
@@ -908,7 +908,7 @@ class LSTMCell(LayerRNNCell):
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
- % str(input_shape))
+ % str(inputs_shape))
input_depth = inputs_shape[-1]
h_depth = self._num_units if self._num_proj is None else self._num_proj
@@ -954,7 +954,7 @@ class LSTMCell(LayerRNNCell):
"""Run one step of LSTM.
Args:
- inputs: input Tensor, 2D, `[batch, num_units].
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
state: if `state_is_tuple` is False, this must be a state Tensor,
`2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 29fefbe3a5..046a48d192 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -29,16 +29,19 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.util import compat as util_compat
# go/tf-wildcard-import
# pylint: disable=wildcard-import
+# pylint: disable=g-bad-import-order
from tensorflow.python.ops.gen_string_ops import *
+from tensorflow.python.util import compat as util_compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
+# pylint: enable=g-bad-import-order
# pylint: enable=wildcard-import
@@ -90,11 +93,6 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
Returns:
string `Tensor` of the same shape as `source` with specified replacements.
"""
- # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
- if not compat.forward_compatible(2018, 10, 10):
- return gen_string_ops.regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
if (isinstance(pattern, util_compat.bytes_or_text_types) and
isinstance(rewrite, util_compat.bytes_or_text_types)):
# When `pattern` and `rewrite` are static through the life of the op we can
@@ -108,6 +106,87 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
rewrite=rewrite, replace_global=replace_global)
+@tf_export("strings.format")
+def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
+ r"""Formats a string template using a list of tensors.
+
+ Formats a string template using a list of tensors, abbreviating tensors by
+ only printing the first and last `summarize` elements of each dimension
+ (recursively). If formatting only one tensor into a template, the tensor does
+ not have to be wrapped in a list.
+
+ Example:
+ Formatting a single-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor = tf.range(10)
+ formatted = tf.strings.format("tensor: {}, suffix", tensor)
+ out = sess.run(formatted)
+ expected = "tensor: [0 1 2 ... 7 8 9], suffix"
+
+ assert(out.decode() == expected)
+ ```
+
+ Formatting a multi-tensor template:
+ ```python
+ sess = tf.Session()
+ with sess.as_default():
+ tensor_one = tf.reshape(tf.range(100), [10, 10])
+ tensor_two = tf.range(10)
+ formatted = tf.strings.format("first: {}, second: {}, suffix",
+ (tensor_one, tensor_two))
+
+ out = sess.run(formatted)
+ expected = ("first: [[0 1 2 ... 7 8 9]\n"
+ " [10 11 12 ... 17 18 19]\n"
+ " [20 21 22 ... 27 28 29]\n"
+ " ...\n"
+ " [70 71 72 ... 77 78 79]\n"
+ " [80 81 82 ... 87 88 89]\n"
+ " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")
+
+ assert(out.decode() == expected)
+ ```
+
+ Args:
+ template: A string template to format tensor values into.
+ inputs: A list of `Tensor` objects, or a single Tensor.
+ The list of tensors to format into the template string. If a solitary
+ tensor is passed in, the input tensor will automatically be wrapped as a
+ list.
+ placeholder: An optional `string`. Defaults to `{}`.
+ At each placeholder occurring in the template, a subsequent tensor
+ will be inserted.
+ summarize: An optional `int`. Defaults to `3`.
+ When formatting the tensors, show the first and last `summarize`
+ entries of each tensor dimension (recursively). If set to -1, all
+ elements of the tensor will be shown.
+ name: A name for the operation (optional).
+
+ Returns:
+ A scalar `Tensor` of type `string`.
+
+ Raises:
+ ValueError: if the number of placeholders does not match the number of
+ inputs.
+ """
+ # If there is only one tensor to format, we will automatically wrap it in a
+ # list to simplify the user experience
+ if tensor_util.is_tensor(inputs):
+ inputs = [inputs]
+ if template.count(placeholder) != len(inputs):
+ raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
+ " provided as input" % (template.count(placeholder),
+ len(inputs)))
+
+ return gen_string_ops.string_format(inputs,
+ template=template,
+ placeholder=placeholder,
+ summarize=summarize,
+ name=name)
+
+
@tf_export("string_split")
def string_split(source, delimiter=" ", skip_empty=True): # pylint: disable=invalid-name
"""Split elements of `source` based on `delimiter` into a `SparseTensor`.
@@ -251,6 +330,17 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+
+# This wrapper provides backwards compatibility for code that predates the
+# unit argument and that passed 'name' as a positional argument.
+@tf_export("strings.length")
+def string_length(input, name=None, unit="BYTE"):
+ return gen_string_ops.string_length(input, unit=unit, name=name)
+
+
+string_length.__doc__ = gen_string_ops.string_length.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 94c7d88b5c..a404507627 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -234,6 +234,7 @@ def create_file_writer(logdir,
"""
if logdir is None:
return SummaryWriter(None, None)
+ logdir = str(logdir)
with ops.device("cpu:0"):
if max_queue is None:
max_queue = constant_op.constant(10)
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
new file mode 100644
index 0000000000..875be31602
--- /dev/null
+++ b/tensorflow/python/ops/while_v2.py
@@ -0,0 +1,580 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""while_v2 and gradient.
+
+This is a version of while_loop that emits a single While op, as well as the
+gradient function for While ops produced by while_loop. This will eventually
+replace the current tf.while_loop implementation once it reaches feature and
+performance parity.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
+from tensorflow.python.util import nest
+
+# pylint: disable=protected-access
+
+# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
+# control dependencies on external nodes with at least 1 output.
+# Another idea is to create const nodes outside the loop and add control edges
+# to them and then pass those in as data inputs. This should probably be
+# handled in the CapturingGraph itself.
+
+
+def while_loop(cond, body, loop_vars, name=None):
+ """Like tf.while_loop, except emits a single While op."""
+ if not name:
+ name = "while"
+
+ with ops.name_scope(name) as scope:
+ with ops.name_scope(None):
+ cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
+ body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
+
+ flattened_loop_vars = nest.flatten(loop_vars)
+ num_outputs = len(flattened_loop_vars)
+
+ # Add loop counter needed for computing gradients.
+ flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
+ ] + flattened_loop_vars
+
+ # Build a `cond` wrapper that can handle the extra counter loop_var.
+ def wrapped_cond(unused_loop_counter, *loop_vars):
+ return cond(*loop_vars)
+
+ cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
+ flattened_loop_vars, {})
+
+ # Add external_captures of cond to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+
+ def wrapped_body(loop_counter, *args):
+ """Loop body augmented with counter update.
+
+ Args:
+ loop_counter: Loop counter which needs to be incremented in the body.
+ *args: List of args
+ args[:num_outputs] - Args for the original loop body.
+ args[num_outputs:] - External captures of cond. These get passed
+ through as is.
+
+ Returns:
+ A list of tensors the same length as args.
+ """
+ outputs = body(*args[:num_outputs])
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+
+ # Return the external_captures of cond_graph as is, i.e., treat them as
+ # loop invariants.
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
+
+ body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
+ flattened_loop_vars, {})
+ # Add external captures of body to the list of loop vars.
+ # Note that external tensors will be treated as loop invariants, i.e.,
+ # the value of that tensor in each iteration is the same as it was at the
+ # beginning of the loop execution.
+ flattened_loop_vars = flattened_loop_vars + body_graph.external_captures
+ # TODO(srbs): Update lowering code to create _Enter nodes with
+ # is_constant=True for inputs that are directly passed to outputs.
+ body_graph.outputs.extend(body_graph.internal_captures)
+
+ # Capture `external_captures` of `body_graph` in `cond_graph` so that it
+ # expects to receive those as arguments.
+ # TODO(srbs): Dedup tensors that are captured in both the cond and body.
+ # This logic already exists in cond_v2.
+ with cond_graph.as_default():
+ for external_capture in body_graph.external_captures:
+ cond_graph.capture(external_capture)
+
+ # Export all tensors in the loop body that may be needed for gradient
+ # computation. We do this by accumulating the intermediate values in
+ # TensorLists.
+ intermediate_tensors = _get_intermediates(body_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ # TODO(srbs): Cache and re-use empty tensor lists.
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(
+ intermediate_tensor.shape))
+ flattened_loop_vars.append(tensor_list)
+ with cond_graph.as_default():
+ # Add a placeholder to cond_graph's inputs corresponding to the
+ # tensor_list.
+ cond_graph.capture(tensor_list)
+ with body_graph.as_default():
+ # Push the intermediate tensor to the tensor list. This captures the
+ # `tensor_list` as well.
+ appended_tensor_list = list_ops.tensor_list_push_back(
+ tensor_list,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_graph.outputs.append(appended_tensor_list)
+
+ outputs = gen_functional_ops._while(
+ flattened_loop_vars,
+ cond_v2._create_new_tf_function(cond_graph),
+ cond_v2._create_new_tf_function(body_graph),
+ name=scope)
+
+ _copy_handle_data(body_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # First var is loop counter.
+ if num_outputs == 1:
+ return outputs[1]
+ else:
+ return nest.pack_sequence_as(loop_vars, outputs[1:1 + num_outputs])
+
+
+@ops.RegisterGradient("While")
+def _WhileGrad(op, *grads): # pylint: disable=invalid-name
+ """The gradient of a While op produced by while_loop."""
+ body_graph = _get_body_graph(op)
+
+ # Replace None gradients with zeros. This is needed because `grads` could have
+ # None incoming gradients for the TensorLists. If we pass None's through, the
+ # custom gradient of TensorListPopBack will create an EmptyTensorList inside
+ # the FuncGraph which is undesirable.
+ # TODO(b/80444525): There might be an issue with treating no gradient as zero
+ # gradient in certain cases. Consider replacing None gradients with Zeros
+ # for accumulators only.
+ grads = [
+ g if g is not None else array_ops.zeros_like(output)
+ for g, output in zip(grads, op.outputs)
+ ]
+
+ body_grad_graph, args = _create_grad_func(
+ body_graph, grads,
+ _get_unique_name("%s_grad" % body_graph.name), op)
+
+ intermediate_tensors = _get_intermediates(body_grad_graph)
+
+ for intermediate_tensor in intermediate_tensors:
+ tensor_list = list_ops.empty_tensor_list(
+ element_dtype=intermediate_tensor.dtype,
+ element_shape=_get_tensor_convertible_shape(intermediate_tensor.shape))
+ with body_grad_graph.as_default():
+ tensor_list_ph = body_grad_graph.capture(tensor_list, whitelisted=True)
+ # Push the intermediate tensor to the tensor list.
+ appended_tensor_list = list_ops.tensor_list_push_back(tensor_list_ph,
+ intermediate_tensor)
+ # Add this modified tensor list to the list of outputs.
+ body_grad_graph.outputs.append(appended_tensor_list)
+
+ def grad_cond(counter, max_iters, *unused_args):
+ return counter < max_iters
+
+ loop_vars = args + body_grad_graph.external_captures
+ cond_grad_graph = function.func_graph_from_py_func(
+ _get_unique_name("%s_grad_cond" % op.name),
+ grad_cond, loop_vars, {})
+
+ assert len(loop_vars) == len(body_grad_graph.inputs)
+ assert len(loop_vars) == len(body_grad_graph.outputs)
+ assert len(loop_vars) == len(cond_grad_graph.inputs)
+
+ outputs = gen_functional_ops._while(
+ loop_vars,
+ cond_v2._create_new_tf_function(cond_grad_graph),
+ cond_v2._create_new_tf_function(body_grad_graph),
+ name=_get_unique_name("%s_grad" % op.name))
+
+ _copy_handle_data(body_grad_graph.outputs, outputs)
+ _maybe_set_lowering_attr(outputs[0].op)
+
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ return outputs[2:2 + len(op.inputs)]
+
+
+# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
+def _get_body_graph(while_op):
+ """Returns `FuncGraph` for the while body.
+
+ Args:
+ while_op: The While Operation.
+
+ Returns:
+ `FuncGraph` for the while body.
+ """
+ extra_inputs = list(while_op.inputs)
+ input_shapes = [t.shape for t in extra_inputs]
+ func_name = while_op.get_attr("body").name
+ fdef = while_op.graph._get_function(func_name).definition
+ func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
+ func_graph._while = while_op
+ return func_graph
+
+
+def _create_grad_func(func_graph, grads, name, while_op):
+ """Builds and returns the gradient FuncGraph of `func_graph` and its args.
+
+ The returned grad_func_graph must be called with the returned
+ args + grad_func_graph.captures.
+
+ Args:
+ func_graph: FuncGraph for the forward body function.
+ grads: The incoming grads for `func_graph`'s outputs.
+ name: Name of the returned gradient function.
+ while_op: The forward While op.
+
+ Returns:
+ 2-tuple of (grad_func_graph, args).
+ """
+ assert len(func_graph.outputs) == len(grads)
+
+ loop_counter = constant_op.constant(0.)
+ # TODO(srbs): For nested while loops will need to lookup this value from
+ # the accumulator of the enclosing while loop. For now use as is assuming
+ # there is no nesting.
+ num_iters_t = while_op.outputs[0]
+
+ args = [loop_counter, num_iters_t] + grads
+
+ # Note: The returned function does not have `args` in the list of
+ # `external_captures`.
+ grad_func_graph = function.func_graph_from_py_func(
+ name,
+ lambda *args: _grad_fn(func_graph, args),
+ args, {},
+ func_graph=_WhileBodyGradFuncGraph(name, func_graph))
+
+ # Add the popped accumulators to the list of outputs.
+ for internal_capture in grad_func_graph.internal_captures:
+ grad_func_graph.outputs.append(
+ grad_func_graph.popped_tensor_lists[internal_capture])
+
+ return grad_func_graph, args
+
+
+def _grad_fn(func_graph, args):
+ """Computes the gradient of `func_graph` in the current graph.
+
+ This function builds the gradient graph of the corresponding forward-pass
+ `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.
+
+ Args:
+ func_graph: function.FuncGraph. The corresponding forward-pass function.
+ args: The input arguments. args[0] - Loop counter args[1] - Total number of
+ iterations.
+ args[2:] - Incoming gradients for `func_graph.outputs`.
+
+ Returns:
+ The output gradient Tensors.
+ """
+ xs = func_graph.inputs
+ ys = func_graph.outputs
+ grad_ys = args[2:]
+
+ # Build the gradient graph. Note that this builds the gradient computation of
+ # func_graph in the current graph, which requires capturing tensors from
+ # func_graph. The captured func_graph tensors are resolved to external tensors
+ # in _resolve_grad_inputs.
+ # TODO(srbs): Mark GradientsHelper as public?
+ grad_outs = gradients_impl._GradientsHelper(
+ ys, xs, grad_ys=grad_ys, src_graph=func_graph)
+
+ assert all([g is not None for g in grad_outs])
+ counter = args[0]
+ total_iters = args[1]
+ return [counter + 1, total_iters] + grad_outs
+
+
+def _get_intermediates(func_graph):
+ """Returns all tensors in `func_graph` that should be accumulated."""
+ # We currently accumulate output tensors of most ops in the function and rely
+ # on the pruning pass to get rid of the unused accumulators at runtime.
+ # However, this can bloat the GraphDef and make debugging harder so we perform
+ # some optimizations.
+ #
+ # Optimization we currently perform:
+ # 1. We do not accumulate tensors which already have an accumulator
+ # in the loop body.
+ # 2. We do not accumulate outputs of Identity nodes. When building the
+ # FuncGraph, we add an Identity node for each output (see
+ # `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
+ # of all these nodes bloats the GraphDef quite a bit so we remove those.
+ # Since the gradient of an Identity node does not rely on its forward op's
+ # input this is safe to do.
+ #
+ # Other possible optimizations:
+ # 1. Only accumulate tensors that will be required by the backward pass.
+ # This will require running the gradient pass and hence would increase the
+ # graph building time for the forward pass.
+ # 2. Do not accumulate Const nodes created inside the loop body.
+ # 3. Do not accumulate inputs that are passed as-is, e.g. loop invariants.
+ # TODO(srbs): 2 and 3 may be hard optimizations for the runtime optimizer
+ # since it requires knowledge of the while loop semantics. If so, consider
+ # doing those here.
+ intermediates = []
+
+ for op in func_graph.get_operations():
+ if op.type == "Identity":
+ continue
+ for o in op.outputs:
+ if (o != func_graph.inputs[0] and # Loop counter.
+ _get_accumulator(o) is None): # Has existing accumulator.
+ intermediates.append(o)
+ return intermediates
+
+
+def _get_accumulator(tensor):
+ r"""Returns TensorList if any containing accumulated values of tensor.
+
+ We try to find a pattern of the form:
+
+ input_tl tensor
+ \ /
+ (TensorListPushBack)
+ |
+ output_tl
+
+ which satisfies the following conditions:
+
+ 1. input_tl must be in tensor.graph.inputs.
+ 2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
+ 3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_t).
+
+ output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
+ returned if such a pattern is found else None is returned.
+
+ Args:
+ tensor: The Tensor to be accumulated.
+
+ Returns:
+ A variant tensor in the same graph as `tensor` or None if no accumulator is
+ found.
+ """
+ assert isinstance(tensor.graph, function.FuncGraph)
+
+ def get_func_graph_output(t):
+ """Returns t or Identity(t) whichever exists in graph outputs else None."""
+ if t in tensor.graph.outputs:
+ return t
+ # tf.defun adds an Identity for each output, check whether that is the case.
+ identity_op = t.consumers()[0]
+ if (identity_op.type == "Identity" and
+ identity_op.outputs[0] in tensor.graph.outputs):
+ return identity_op.outputs[0]
+ return None
+
+ for consumer in tensor.consumers():
+ # Find the consumer that is a TensorListPushBack node whose TensorList input
+ # is in the list of function inputs.
+ if (consumer.type != "TensorListPushBack" or
+ consumer.inputs[0] not in tensor.graph.inputs):
+ continue
+
+ output = get_func_graph_output(consumer.outputs[0])
+ if output is None:
+ # The TensorList output of `consumer` is not in the list of function
+ # outputs.
+ continue
+
+ accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
+ accum_output_idx = tensor.graph.outputs.index(output)
+ if accum_input_idx == accum_output_idx:
+ return output
+ return None
+
+
+# TODO(srbs): Add to common utils for cond_v2 and while_v2.
+def _get_unique_name(name):
+ """Returns a name that is unique in the root graph of `func_graph`.
+
+ Args:
+ name: String to uniquify.
+
+ Returns:
+ A string.
+ """
+ with ops.init_scope():
+ return ops.get_default_graph().unique_name(name)
+
+
+class _WhileBodyGradFuncGraph(function.FuncGraph):
+ """FuncGraph for the gradient function of the body of a While op.
+
+ Contains the logic for capturing the tensors from the body of the forward
+ While op which is as follows:
+ 1. Find the accumulator for that tensor.
+ 2. Capture the forward While op output tensor corresponding to the
+ accumulator in this FuncGraph.
+ 3. Pop a value from the captured placeholder and use it as the captured value
+ for the forward pass tensor.
+
+ This only allows capturing tensors in the forward graph. A ValueError is
+ raised if an attempt is made to capture a tensor not in the forward graph.
+ To manually capture capture a tensor that is not in the forward graph, call
+ `capture` with `whitelisted=True`.
+
+ Note: The `captures` dict does not contain the forward tensor since it is not
+ directly captured. It contains the accumulator corresponding to this forward
+ tensor.
+
+ Attributes:
+ popped_tensor_lists: Dict from the captured accumulator placeholder to the
+ TensorList obtained after popping the intermediate tensor from it. The
+ values of this dict need to be added to the list of outputs.
+ """
+
+ def __init__(self, name, forward_graph):
+ super(_WhileBodyGradFuncGraph, self).__init__(name)
+ self.popped_tensor_lists = {}
+ # FuncGraph for the body of the forward While op.
+ self._forward_graph = forward_graph
+ # Dict from forward intermediate tensor to the corresponding "popped" tensor
+ # in this graph.
+ self._indirect_captures = {}
+ # Dict from forward graph tensor to the While op output corresponding to its
+ # accumulator.
+ self._tensor_to_accumulator = {}
+
+ def capture(self, tensor, name=None, whitelisted=False):
+ """Selectively captures external tensors.
+
+ If `whitelisted` is False only allows capturing tensors in the
+ `_forward_graph`.
+
+ Args:
+ tensor: Tensor. May be from this FuncGraph or a different graph.
+ name: Optional name if a placeholder is created.
+ whitelisted: If False (default), only allows capturing tensors from the
+ forward graph.
+
+ Returns:
+ The placeholder in this graph for the tensor.
+
+ Raises:
+ ValueError: If attempting to capture an external tensor not in the forward
+ graph with `whitelisted` set to False.
+ """
+ if (not whitelisted and tensor.graph is not self and
+ tensor.graph != self._forward_graph):
+ raise ValueError("Attempting to capture tensor", str(tensor),
+ " which is not in the forward graph but in ",
+ _graph_name(tensor.graph), ".")
+ return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
+
+ def _capture_helper(self, tensor, name):
+ if tensor.graph is not self._forward_graph:
+ return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)
+
+ captured_tensor = self._indirect_captures.get(tensor)
+ if captured_tensor is not None:
+ # For GradientTape housekeeping.
+ assert self._tensor_to_accumulator[tensor] in self.captures
+ super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ self._tensor_to_accumulator[tensor], name)
+ return captured_tensor
+
+ assert tensor not in self._tensor_to_accumulator
+
+ accumulator = None
+
+ # Find the TensorList that was used to accumulate the tensors of this
+ # intermediate tensor.
+ accumulator = _get_accumulator(tensor)
+ if accumulator is None:
+ raise ValueError("Reference to un-accumulated intermediate tensor: ",
+ tensor.name)
+ assert accumulator.graph == self._forward_graph
+ # Get the While op output corresponding to the accumulator.
+ accumulator = self._forward_graph._while.outputs[self._forward_graph.outputs
+ .index(accumulator)]
+
+ assert accumulator.graph == self._forward_graph.outer_graph
+ self._tensor_to_accumulator[tensor] = accumulator
+
+ # Capture the `accumulator`.
+ accumulator_ph = super(_WhileBodyGradFuncGraph, self)._capture_helper(
+ accumulator, name)
+ new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
+ accumulator_ph, element_dtype=tensor.dtype)
+ self._indirect_captures[tensor] = captured_tensor
+ self.popped_tensor_lists[accumulator_ph] = new_tensor_list
+ return captured_tensor
+
+
+def _copy_handle_data(src_tensors, tgt_tensors):
+ for src_t, tgt_t in zip(src_tensors, tgt_tensors):
+ function._copy_handle_data(src_t, tgt_t)
+
+
+# TODO(srbs): Move to common utils for cond_v2 and while_v2.
+def _maybe_set_lowering_attr(op):
+ """Sets the flag to enable lowering on the `While` op if necessary.
+
+ Lowering allows while_v2 to avoid some of the limitations of Functions,
+ allowing users to specify devices & colocation inside of while_v2
+ branches, and enabling non-strict evaluation & partial pruning of while_v2
+ branches. This brings while_v2 closer to feature parity with
+ tf.while_loop.
+
+ However, we do not lower `While` in the XLA context because it is easier
+ for XLA to apply its own optimizations when dealing with un-lowered
+ `While` operators than with low-level control flow primitives.
+
+ Args:
+ op: The While op.
+ """
+ if not control_flow_util.IsInXLAContext(op):
+ # pylint: disable=protected-access
+ op._set_attr("_lower_using_switch_merge", attr_value_pb2.AttrValue(b=True))
+ # pylint: enable=protected-access
+
+
+def _get_tensor_convertible_shape(shape):
+ assert isinstance(shape, tensor_shape.TensorShape)
+ if shape.is_fully_defined():
+ return shape
+ if not shape: # Unknown shape.
+ return -1
+ # Partially defined shape.
+ shape_list = shape.as_list()
+ shape_list = [s if s is not None else -1 for s in shape_list]
+ return ops.convert_to_tensor(shape_list)
+
+
+def _graph_name(graph):
+ if isinstance(graph, function.FuncGraph):
+ return graph.name
+ return "Base"
+
+
+# pylint: enable=protected-access
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index 45de047894..5927bc2409 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -33,6 +33,7 @@ from tensorflow.python.lib.io.file_io import rename as Rename
from tensorflow.python.lib.io.file_io import stat as Stat
from tensorflow.python.lib.io.file_io import walk as Walk
# pylint: enable=unused-import
+from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -62,6 +63,7 @@ class FastGFile(_FileIO):
invocations in network filesystems).
"""
+ @deprecated(None, 'Use tf.gfile.GFile.')
def __init__(self, name, mode='r'):
super(FastGFile, self).__init__(name=name, mode=mode)
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index c0e16ca536..94c685274a 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -52,13 +52,19 @@ builder = option_builder.ProfileOptionBuilder
class PrintModelAnalysisTest(test.TestCase):
+ def _no_rewrite_session_config(self):
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ pin_to_host_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
def testDumpToFile(self):
ops.reset_default_graph()
outfile = os.path.join(test.get_temp_dir(), 'dump')
opts = builder(builder.trainable_variables_parameter()
).with_file_output(outfile).build()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
_ = lib.BuildSmallModel()
model_analyzer.profile(sess.graph, options=opts)
@@ -83,7 +89,8 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess, ops.device(dev):
+ with session.Session(
+ config=self._no_rewrite_session_config()) as sess, ops.device(dev):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -149,11 +156,8 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['params', 'float_ops', 'occurrence', 'device', 'op_types',
'input_shapes']).build())
- rewriter_config = rewriter_config_pb2.RewriterConfig(
- disable_model_pruning=True)
- graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
- config = config_pb2.ConfigProto(graph_options=graph_options)
- with session.Session(config=config) as sess, ops.device('/device:CPU:0'):
+ with session.Session(config=self._no_rewrite_session_config()
+ ) as sess, ops.device('/device:CPU:0'):
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -179,7 +183,7 @@ class PrintModelAnalysisTest(test.TestCase):
.select(['bytes', 'params', 'float_ops', 'num_hidden_ops', 'device',
'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -213,7 +217,7 @@ class PrintModelAnalysisTest(test.TestCase):
with profile_context.ProfileContext(test.get_temp_dir(),
trace_steps=[],
dump_steps=[]) as pctx:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -274,7 +278,7 @@ class PrintModelAnalysisTest(test.TestCase):
.account_displayed_op_only(False)
.select(['bytes', 'params', 'float_ops', 'device']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
@@ -302,7 +306,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_timeline_output(outfile)
.with_accounted_types(['.*']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -338,7 +342,7 @@ class PrintModelAnalysisTest(test.TestCase):
'peak_bytes', 'residual_bytes',
'output_bytes', 'occurrence', 'input_shapes']).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -384,7 +388,7 @@ class PrintModelAnalysisTest(test.TestCase):
def testAdvisor(self):
ops.reset_default_graph()
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -417,7 +421,7 @@ class PrintModelAnalysisTest(test.TestCase):
.with_node_names(trim_name_regexes=['ops.py.*'])
.with_pprof_output(outfile).build())
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildFullModel()
sess.run(variables.global_variables_initializer())
@@ -484,7 +488,7 @@ class PrintModelAnalysisTest(test.TestCase):
self.assertGreaterEqual(n.output_bytes, mob)
check_min(n.children, mm, mam, mcm, mb, mpb, mrb, mob)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -549,7 +553,7 @@ class PrintModelAnalysisTest(test.TestCase):
for attr in not_selected:
self.assertFalse(s.find(attr) > 0, s)
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
x = lib.BuildSmallModel()
sess.run(variables.global_variables_initializer())
run_meta = config_pb2.RunMetadata()
@@ -582,7 +586,7 @@ class PrintModelAnalysisTest(test.TestCase):
def _trainLoop(self, train_op, train_steps, time_dir, time_step,
memory_dir, memory_step, profile_dir, dump_step):
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(variables.global_variables_initializer())
# start from 1 because variable_initializer took one step.
for i in range(1, train_steps + 1):
@@ -655,7 +659,7 @@ class PrintModelAnalysisTest(test.TestCase):
c = a * b
try:
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
sess.run(c, options=config_pb2.RunOptions(
report_tensor_allocations_upon_oom=True))
except Exception as e: # pylint: disable=broad-except
@@ -758,7 +762,7 @@ class PrintModelAnalysisTest(test.TestCase):
grad = gradients.gradients(y, [x1])
- with session.Session() as sess:
+ with session.Session(config=self._no_rewrite_session_config()) as sess:
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
run_metadata = config_pb2.RunMetadata()
diff --git a/tensorflow/python/profiler/pprof_profiler_test.py b/tensorflow/python/profiler/pprof_profiler_test.py
index c2469f012d..11a3487360 100644
--- a/tensorflow/python/profiler/pprof_profiler_test.py
+++ b/tensorflow/python/profiler/pprof_profiler_test.py
@@ -141,7 +141,7 @@ comment: 9
run_metadata = config_pb2.RunMetadata()
num_iters = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = constant_op.constant(0)
c = lambda i: math_ops.less(i, num_iters)
b = lambda i: math_ops.add(i, 1)
diff --git a/tensorflow/python/pywrap_tensorflow.py b/tensorflow/python/pywrap_tensorflow.py
index 5c0c5783dc..f0724277d3 100644
--- a/tensorflow/python/pywrap_tensorflow.py
+++ b/tensorflow/python/pywrap_tensorflow.py
@@ -68,7 +68,7 @@ try:
sys.setdlopenflags(_default_dlopen_flags)
except ImportError:
msg = """%s\n\nFailed to load the native TensorFlow runtime.\n
-See https://www.tensorflow.org/install/install_sources#common_installation_problems\n
+See https://www.tensorflow.org/install/errors\n
for some common reasons and solutions. Include the entire stack trace
above this error message when asking for help.""" % traceback.format_exc()
raise ImportError(msg)
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index be8f425481..c411a58b70 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -188,7 +188,10 @@ limitations under the License.
"outputs of the operation)");
}
$1 = &temp;
- $1->resize(PyInt_AsLong($input), nullptr);
+ long sz = PyInt_AsLong($input);
+ if (sz > 0) {
+ $1->resize(PyInt_AsLong($input), nullptr);
+ }
}
// Create new Status object.
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index 5eeaf73a43..fe69f3beb0 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -91,10 +91,17 @@ with an asset of the same name, only the first version is retained.
#### Tags
Each meta graph added to the SavedModel must be annotated with user specified
-tags. The tags provide a means to identify the specific meta graph to load and
-restore, along with the shared set of variables and assets. These tags
-typically annotate a MetaGraph with its functionality (e.g. serving or
-training), and possibly hardware specific aspects such as GPU.
+tags, which reflect the meta graph capabilities or use-cases.
+More specifically, these tags typically annotate a meta graph with its
+functionality (e.g. serving or training), and possibly hardware specific aspects
+such as GPU.
+In the SavedModel, the meta graph def whose tag-set exactly matches those
+specified in the loader API, will be the one loaded by the loader.
+If no meta graph def is found matching the specified tags, an error is returned.
+For example, a loader with a requirement to serve on GPU hardware would be able
+to load only meta graph annotated with tags='serve,gpu' by specifying this set
+of tags in tensorflow::LoadSavedModel(...).
+
#### Usage
The typical usage of `builder` is as follows:
diff --git a/tensorflow/python/summary/writer/event_file_writer.py b/tensorflow/python/summary/writer/event_file_writer.py
index 2936a279bd..14dec982a6 100644
--- a/tensorflow/python/summary/writer/event_file_writer.py
+++ b/tensorflow/python/summary/writer/event_file_writer.py
@@ -62,7 +62,7 @@ class EventFileWriter(object):
filename_suffix: A string. Every event file's name is suffixed with
`filename_suffix`.
"""
- self._logdir = logdir
+ self._logdir = str(logdir)
if not gfile.IsDirectory(self._logdir):
gfile.MakeDirs(self._logdir)
self._event_queue = six.moves.queue.Queue(max_queue)
diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py
index dc990c2602..670230e917 100644
--- a/tensorflow/python/summary/writer/writer_test.py
+++ b/tensorflow/python/summary/writer/writer_test.py
@@ -286,7 +286,7 @@ class FileWriterTestCase(test.TestCase):
def testAddingSummariesFromSessionRunCalls(self):
test_dir = self._CleanTestDir("global_step")
sw = self._FileWriter(test_dir)
- with self.test_session():
+ with self.cached_session():
i = constant_op.constant(1, dtype=dtypes.int32, shape=[])
l = constant_op.constant(2, dtype=dtypes.int64, shape=[])
# Test the summary can be passed serialized.
@@ -437,7 +437,7 @@ class SessionBasedFileWriterTestCase(FileWriterTestCase):
# Pass in test_session() as the session. It will be cached during this
# test method invocation so that any other use of test_session() with no
# graph should result in re-using the same underlying Session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
kwargs["session"] = sess
return writer.FileWriter(*args, **kwargs)
return writer.FileWriter(*args, **kwargs)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 1c1a1a54cd..75824d83e6 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -44,6 +44,7 @@ py_library(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:loader",
"@six_archive//:six",
],
diff --git a/tensorflow/python/tools/api/generator/create_python_api.py b/tensorflow/python/tools/api/generator/create_python_api.py
index 67cfd799ff..ab749f28cd 100644
--- a/tensorflow/python/tools/api/generator/create_python_api.py
+++ b/tensorflow/python/tools/api/generator/create_python_api.py
@@ -181,7 +181,6 @@ class _ModuleInitCodeBuilder(object):
_names_with_underscore = [%s]
__all__ = [_s for _s in dir() if not _s.startswith('_')]
__all__.extend([_s for _s in _names_with_underscore])
-__all__.remove('print_function')
''' % underscore_names_str
return module_text_map
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index fcb3ceac82..a39c046761 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -129,7 +129,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertProtoEquals(expected_output, output)
def testFoldBatchNorms(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
@@ -161,7 +161,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
original_graph_def)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -224,7 +224,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -242,7 +242,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -255,7 +255,7 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("ResizeBilinear", node.op)
def testFuseResizeAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -271,7 +271,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
@@ -284,7 +284,7 @@ class OptimizeForInferenceTest(test.TestCase):
def testFusePadAndConv(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
input_op = constant_op.constant(
np.array(inputs), shape=[1, 2, 3, 2], dtype=dtypes.float32)
@@ -300,7 +300,7 @@ class OptimizeForInferenceTest(test.TestCase):
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
original_graph_def, ["output"])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_ = importer.import_graph_def(
optimized_graph_def, input_map={}, name="optimized")
optimized_result = sess.run(["optimized/output:0"])
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index c5289564fe..3dbccd1409 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -15,7 +15,7 @@
"""Command-line interface to inspect and execute a graph in a SavedModel.
For detailed usages and examples, please refer to:
-https://www.tensorflow.org/guide/saved_model_cli
+https://www.tensorflow.org/guide/saved_model#cli_to_inspect_and_execute_savedmodel
"""
@@ -33,7 +33,6 @@ import numpy as np
from six import integer_types
from tensorflow.contrib.saved_model.python.saved_model import reader
-from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
@@ -97,8 +96,7 @@ def _get_inputs_tensor_info_from_meta_graph_def(meta_graph_def,
Returns:
A dictionary that maps input tensor keys to TensorInfos.
"""
- return signature_def_utils.get_signature_def_by_key(meta_graph_def,
- signature_def_key).inputs
+ return meta_graph_def.signature_def[signature_def_key].inputs
def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
@@ -116,8 +114,7 @@ def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
Returns:
A dictionary that maps output tensor keys to TensorInfos.
"""
- return signature_def_utils.get_signature_def_by_key(meta_graph_def,
- signature_def_key).outputs
+ return meta_graph_def.signature_def[signature_def_key].outputs
def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0):
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 3508b98475..cc0da26b27 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -34,7 +34,7 @@ class AdagradOptimizer(optimizer.Optimizer):
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
or this
- [intro](http://cs.stanford.edu/~ppasupat/a9online/uploads/proximal_notes.pdf).
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
def __init__(self, learning_rate, initial_accumulator_value=0.1,
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 21ca1735e0..419a9ec12b 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -195,6 +195,10 @@ class _SameScopeAgainContext(object):
class DistributionStrategy(object):
"""A list of devices with a state & compute distribution policy.
+ See [tensorflow/contrib/distribute/README.md](
+ https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+ for overview and examples.
+
The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different `DistributionStrategy`
implementations. Each descendant will implement a different strategy
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index 09d6fe36d3..15c50bc878 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -218,7 +218,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL1_L2_L2ShrinkageSparse(self):
"""Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -252,7 +252,7 @@ class FtrlOptimizerTest(test.TestCase):
def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
"""Verifies that l2 shrinkage in FTRL does not change lr schedule."""
for dtype in [dtypes.half, dtypes.float32]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([1.0, 2.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
index 56d82a5b88..1ddea598e5 100644
--- a/tensorflow/python/training/gradient_descent_test.py
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -252,12 +252,12 @@ class GradientDescentOptimizerTest(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(1.0)
def step():
- v = resource_variable_ops.ResourceVariable(1.0)
+ self.v = resource_variable_ops.ResourceVariable(1.0)
with backprop.GradientTape() as tape:
- loss = v ** 2
- grad = tape.gradient(loss, v)
- optimizer.apply_gradients([(grad, v)])
- return v.read_value()
+ loss = self.v ** 2
+ grad = tape.gradient(loss, self.v)
+ optimizer.apply_gradients([(grad, self.v)])
+ return self.v.read_value()
compiled_step = function.defun(step)
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
index 0f2d60dafc..b2ac93f06f 100644
--- a/tensorflow/python/training/learning_rate_decay_v2_test.py
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -62,7 +62,7 @@ class LRDecayTestV2(test_util.TensorFlowTestCase):
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
def testVariables(self):
- with self.test_session():
+ with self.cached_session():
step = variables.Variable(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 0e0125a956..82f0e3be52 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -1114,7 +1114,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised while a session was being created. '
'This may be due to a preemption of a connected worker '
'or parameter server. A new session will be created. '
- 'Error: %s', e)
+ 'This error may also occur due to a gRPC failure caused '
+ 'by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
def _check_stop(self):
try:
@@ -1127,7 +1131,11 @@ class _RecoverableSession(_WrappedSession):
'session is complete. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = self._create_session()
# Since we have just recreated the session, the overall computation should
@@ -1150,7 +1158,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
@@ -1166,7 +1178,11 @@ class _RecoverableSession(_WrappedSession):
logging.info('An error was raised. This may be due to a preemption in '
'a connected worker or parameter server. The current '
'session will be closed and a new session will be '
- 'created. Error: %s', e)
+ 'created. This error may also occur due to a gRPC failure '
+ 'caused by high memory or network bandwidth usage in the '
+ 'parameter servers. If this error occurs repeatedly, try '
+ 'increasing the number of parameter servers assigned to '
+ 'the job. Error: %s', e)
self.close()
self._sess = None
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 2304a461c1..699162b30c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -385,13 +385,12 @@ class Optimizer(
@compatibility(eager)
When eager execution is enabled, `loss` should be a Python function that
- takes elements of `var_list` as arguments and computes the value to be
- minimized. If `var_list` is None, `loss` should take no arguments.
- Minimization (and gradient computation) is done with respect to the
- elements of `var_list` if not None, else with respect to any trainable
- variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
+ takes no arguments and computes the value to be minimized. Minimization (and
+ gradient computation) is done with respect to the elements of `var_list` if
+ not None, else with respect to any trainable variables created during the
+ execution of the `loss` function. `gate_gradients`, `aggregation_method`,
+ `colocate_gradients_with_ops` and `grad_loss` are ignored when eager
+ execution is enabled.
@end_compatibility
"""
grads_and_vars = self.compute_gradients(
diff --git a/tensorflow/python/training/quantize_training.i b/tensorflow/python/training/quantize_training.i
index 41e62e0252..1ab600bb22 100644
--- a/tensorflow/python/training/quantize_training.i
+++ b/tensorflow/python/training/quantize_training.i
@@ -55,6 +55,13 @@ PyObject* DoQuantizeTrainingOnGraphDefHelper(
%insert("python") %{
+from tensorflow.python.util import deprecation
+from tensorflow.python.util.tf_export import tf_export
+
+@deprecation.deprecated(
+ None,
+ "GraphDef quantized training rewriter is deprecated in the long term")
+@tf_export(v1=["train.do_quantize_training_on_graphdef"])
def do_quantize_training_on_graphdef(input_graph, num_bits):
"""A general quantization scheme is being developed in `tf.contrib.quantize`.
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 274c856686..5b2b19e913 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -622,6 +622,14 @@ class BaseSaverBuilder(object):
yield BaseSaverBuilder.ResourceVariableSaveable(
variable, variable._save_slice_info.spec, name)
# pylint: enable=protected-access
+ elif isinstance(op, checkpointable.CheckpointableBase) and not isinstance(
+ op, variables.Variable):
+ # pylint: disable=protected-access
+ for attr, factory in op._gather_saveables_for_checkpoint().items():
+ op = (factory(name + "_" + attr) if callable(factory) else factory)
+ for op in BaseSaverBuilder.SaveableObjectsForOp(op, op.name):
+ yield op
+ # pylint: enable=protected-access
else:
# A variable or tensor.
if context.executing_eagerly():
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 0ac84813c8..69b1055ebe 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2850,30 +2850,32 @@ class CheckpointableCompatibilityTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testNotSaveableButIsCheckpointable(self):
v = _OwnsAVariableSimple()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- with self.cached_session() as sess:
- self.evaluate(v.non_dep_variable.assign(42.))
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
@test_util.run_in_graph_and_eager_modes
def testMoreComplexSaveableReturned(self):
v = _OwnsMirroredVariables()
- saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
self.evaluate(v.non_dep_variable.assign(42.))
- with self.cached_session() as sess:
- save_path = saver.save(sess, prefix)
- self.evaluate(v.non_dep_variable.assign(43.))
- self.evaluate(v.mirrored.assign(44.))
- saver.restore(sess, save_path)
- self.assertEqual(42., self.evaluate(v.non_dep_variable))
- self.assertEqual(42., self.evaluate(v.mirrored))
+ for saver in (saver_module.Saver(var_list=[v]),
+ saver_module.Saver(var_list={"v": v})):
+ with self.cached_session() as sess:
+ save_path = saver.save(sess, prefix)
+ self.evaluate(v.non_dep_variable.assign(43.))
+ self.evaluate(v.mirrored.assign(44.))
+ saver.restore(sess, save_path)
+ self.assertEqual(42., self.evaluate(v.non_dep_variable))
+ self.assertEqual(42., self.evaluate(v.mirrored))
def testSingleTensorEvaluation(self):
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index 6c860cd452..3eddf79e34 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -203,7 +203,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
@@ -279,7 +279,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.],
@@ -337,7 +337,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
@@ -403,7 +403,7 @@ class WarmStartingUtilTest(test.TestCase):
"new_vocab")
# New session and new graph.
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as sess:
+ with self.session(graph=g) as sess:
fruit_output_layer = variable_scope.get_variable(
"fruit_output_layer",
shape=[4, 3],
diff --git a/tensorflow/python/util/memory.py b/tensorflow/python/util/memory.py
new file mode 100644
index 0000000000..e78f6d509a
--- /dev/null
+++ b/tensorflow/python/util/memory.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions related to Python memory management."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# TODO(b/115366440): Delete this function when a custom OrderedDict is added
+def dismantle_ordered_dict(ordered_dict):
+ """Remove reference cycle in OrderedDict `ordered_dict`.
+
+ Helpful for making sure the garbage collector doesn't need to run after
+ using an OrderedDict.
+
+ Args:
+ ordered_dict: A `OrderedDict` object to destroy. This object is unusable
+ after this function runs.
+ """
+ # OrderedDict, makes a simple reference loop
+ # and hides it in an __attribute in some Python versions. We don't need to
+ # throw an error if we can't find it, but if we do find it we can break the
+ # loop to avoid creating work for the garbage collector.
+ problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None) # pylint: disable=protected-access
+ if problematic_cycle:
+ try:
+ del problematic_cycle[0][:]
+ except TypeError:
+ # This is probably not one of the problematic Python versions. Continue
+ # with the rest of our cleanup.
+ pass
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 2968ca9c07..653ca525dc 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -118,6 +118,18 @@ flatten = _pywrap_tensorflow.Flatten
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
+class _DotString(object):
+
+ def __str__(self):
+ return "."
+
+ def __repr__(self):
+ return "."
+
+
+_DOT = _DotString()
+
+
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
@@ -149,7 +161,15 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ try:
+ _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ except (ValueError, TypeError) as e:
+ str1 = str(map_structure(lambda _: _DOT, nest1))
+ str2 = str(map_structure(lambda _: _DOT, nest2))
+ raise type(e)("%s\n"
+ "Entire first structure:\n%s\n"
+ "Entire second structure:\n%s"
+ % (str(e), str1, str2))
def flatten_dict_items(dictionary):
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index ef503137d1..bfb4c6f910 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -264,7 +264,11 @@ class NestTest(parameterized.TestCase, test.TestCase):
"Second structure:.*\n\n"
"More specifically: Substructure "
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
- 'substructure "type=str str=spam" is not')):
+ 'substructure "type=str str=spam" is not\n'
+ "Entire first structure:\n"
+ r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
+ "Entire second structure:\n"
+ r"\(\., \.\)")):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 778121e15b..967c872c2a 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -325,6 +325,11 @@ def isfunction(object): # pylint: disable=redefined-builtin
return _inspect.isfunction(tf_decorator.unwrap(object)[1])
+def isgenerator(object): # pylint: disable=redefined-builtin
+ """TFDecorator-aware replacement for inspect.isgenerator."""
+ return _inspect.isgenerator(tf_decorator.unwrap(object)[1])
+
+
def ismethod(object): # pylint: disable=redefined-builtin
"""TFDecorator-aware replacement for inspect.ismethod."""
return _inspect.ismethod(tf_decorator.unwrap(object)[1])
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6d336ac39d..104a615636 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -104,9 +104,36 @@ Raises:
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;
+%feature("docstring") tensorflow::swig::IsSequenceForData
+"""Returns a true if `seq` is a Sequence or dict (except strings/lists).
+
+NOTE(mrry): This differs from `tensorflow.python.util.nest.is_sequence()`,
+which *does* treat a Python list as a sequence. For ergonomic
+reasons, `tf.data` users would prefer to treat lists as
+implicit `tf.Tensor` objects, and dicts as (nested) sequences.
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the sequence is a not a string or list and is a
+ collections.Sequence.
+"""
%unignore tensorflow::swig::IsSequenceForData;
%noexception tensorflow::swig::IsSequenceForData;
+%feature("docstring") tensorflow::swig::FlattenForData
+"""Returns a flat sequence from a given nested structure.
+
+If `nest` is not a sequence, this returns a single-element list: `[nest]`.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object.
+ Note, numpy arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+"""
%unignore tensorflow::swig::FlattenForData;
%noexception tensorflow::swig::FlattenForData;
diff --git a/tensorflow/requirements.txt b/tensorflow/requirements.txt
new file mode 100644
index 0000000000..6e111edefc
--- /dev/null
+++ b/tensorflow/requirements.txt
@@ -0,0 +1,2 @@
+keras_applications >= 1.0.5
+keras_preprocessing >= 1.0.3
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 3c533c7f99..ca90c383f9 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
@@ -132,23 +133,42 @@ string ToString(cudnnStatus_t status) {
}
template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+ dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
return CUDNN_DATA_DOUBLE;
}
template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
return CUDNN_DATA_FLOAT;
}
template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
return CUDNN_DATA_HALF;
}
+template <>
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+ switch (layout) {
+ case dnn::DataLayout::kYXDepthBatch:
+ case dnn::DataLayout::kYXBatchDepth:
+ case dnn::DataLayout::kBatchYXDepth:
+ case dnn::DataLayout::kBatchDepthYX:
+ return CUDNN_DATA_INT8;
+ case dnn::DataLayout::kBatchDepthYX4:
+ return CUDNN_DATA_INT8x4;
+ }
+}
+
+template <>
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
+ return CUDNN_DATA_INT32;
+}
+
// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
@@ -2387,6 +2407,33 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
}
}
+// Determines whether we can safely perform a winograd non-fused convolution for
+// the given input and output shapes. This works around b/68264959, an integer
+// overflow in cuDNNv5 and cuDNNv6.
+#if CUDNN_VERSION >= 7000
+bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
+ const dnn::BatchDescriptor&) {
+ return true;
+}
+#else
+bool ShouldIncludeWinogradNonfusedAlgo(
+ const dnn::BatchDescriptor& input_desc,
+ const dnn::BatchDescriptor& output_desc) {
+ int64 batch = input_desc.count();
+ int64 in_depths = input_desc.feature_map_count();
+ int64 in_rows = input_desc.height();
+ int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
+ int64 out_depths = output_desc.feature_map_count();
+
+ int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
+ std::max(in_depths, out_depths) * in_cols * in_rows *
+ sizeof(float);
+
+ const int64 threshold = 1L << 31;
+ return total_size < threshold;
+}
+#endif
+
} // namespace
template <class T>
@@ -2465,6 +2512,13 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
@@ -2486,19 +2540,19 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}
-template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
if (activation_mode != dnn::ActivationMode::kRelu &&
@@ -2509,14 +2563,17 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
}
CudnnTensorDescriptor conv_input_nd(
- conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ conv_input_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
CudnnTensorDescriptor output_nd(
- output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnFilterDescriptor filter(filter_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
- CudnnConvolutionDescriptor conv(
- convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+ output_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnFilterDescriptor filter(
+ filter_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetCudnnDataType<AccumulatorType>());
auto cudnn = cudnn_->GetHandle(parent_, stream);
@@ -2566,6 +2623,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
+ output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See around b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
@@ -2933,8 +2998,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
+ DoFusedConvolveImpl<double>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2957,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2982,8 +3045,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3014,8 +3076,7 @@ bool CudnnSupport::DoFusedConvolve(
return false;
}
return IsStatusOk(
- DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
+ DoFusedConvolveImpl<int32>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3096,6 +3157,13 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
}
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
// zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
@@ -3275,6 +3343,13 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Zero out the result buffer for strided conv backward filter for NHWC
// layouts. cuDNN 7.1.4 and 7.2 has non-determinisic bug if the buffer is not
// zeroed.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 9d88f971bb..74f6f935b8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -674,19 +674,21 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
- template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+ template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data,
+ ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index 7f99d81ef3..a4580d6462 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -359,9 +358,8 @@ class DeviceDescriptionBuilder {
bool ThreadDimOk(const DeviceDescription &device_description,
const ThreadDim &thread_dim);
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
// Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
uint64 DivideCeil(uint64 x, uint64 y);
// Calculate the number of threads/blocks required to process element_count
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 9abfa1db6a..621b155240 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -873,7 +873,7 @@ class NormalizeDescriptor {
// Describes a kind of non-linearity (threshold-like mathematical function).
enum class ActivationMode {
- kNone,
+ kNone = 0,
kSigmoid,
// Rectified linear activation: f(x) = x < 0 ? 0 : x
kRelu,
@@ -885,6 +885,8 @@ enum class ActivationMode {
kTanh,
// Like ReluX, but passes all values in the range [-X,X].
kBandPass,
+
+ kNumActivationModes, // Always in the end.
};
// Returns a string representation of the given activation mode.
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
index 8e3c4ca047..5f4e586762 100644
--- a/tensorflow/stream_executor/lib/array_slice.h
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::MutableArraySlice;
+template <typename T>
+using ArraySlice = absl::Span<const T>;
+template <typename T>
+using MutableArraySlice = absl::Span<T>;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
index 40bdddb180..0198947e5b 100644
--- a/tensorflow/stream_executor/lib/inlined_vector.h
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::InlinedVector;
+using absl::InlinedVector;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index c959e4df5b..3688d7b4eb 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -18,13 +18,13 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "absl/strings/str_cat.h"
namespace stream_executor {
namespace port {
-using tensorflow::strings::StrCat;
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
+using absl::StrCat;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
index b80de5df30..7624910129 100644
--- a/tensorflow/stream_executor/lib/stringpiece.h
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -16,13 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/stream_executor/platform/port.h"
+#include "absl/strings/string_view.h"
namespace stream_executor {
namespace port {
-using tensorflow::StringPiece;
+using StringPiece = absl::string_view;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ecd24..3065b5cb77 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@ class PluginRegistry {
// TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
// on MultiPlatformManager / PlatformId.
template <typename FactoryT>
+ ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
PluginId plugin_id);
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 19d3b2389a..69558fd14b 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch(
Stream &Stream::ThenFusedConvolveWithAlgorithm(
const dnn::BatchDescriptor &conv_input_descriptor,
+ const DeviceMemory<double> &conv_input_data, double conv_input_scale,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<double> &side_input_data, double side_input_scale,
+ const dnn::BatchDescriptor &bias_descriptor,
+ const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+ PARAM(conv_input_scale), PARAM(filter_descriptor),
+ PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(side_input_data), PARAM(side_input_scale),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
+ PARAM(algorithm_config));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoFusedConvolve(
+ this, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output, scratch_allocator,
+ algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e62a..10bf006787 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <atomic>
#include <utility>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
CheckPlatformKindIsValid(platform_kind);
}
+// Get per-device memory limit in bytes. Returns 0 if
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+ int64 value;
+ SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+ 0, &value));
+ return value * (1ll << 20);
+}
+
StreamExecutor::StreamExecutor(
const Platform *platform,
std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor(
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
- tracing_enabled_(false) {
+ tracing_enabled_(false),
+ mem_alloc_bytes_(0),
+ memory_limit_bytes_(GetMemoryLimitBytes()) {
if (port::Lowercase(platform_->Name()) == "cuda") {
platform_kind_ = PlatformKind::kCuda;
} else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
}
void *StreamExecutor::Allocate(uint64 size) {
+ if (memory_limit_bytes_ > 0 &&
+ mem_alloc_bytes_ + size > memory_limit_bytes_) {
+ LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+ << device_ordinal_
+ << " within provided limit. [used=" << mem_alloc_bytes_
+ << ", limit=" << memory_limit_bytes_ << "]";
+ return nullptr;
+ }
void *buf = implementation_->Allocate(size);
VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
<< buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
+ mem_alloc_bytes_ += bytes;
}
}
@@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
} else {
+ mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
mem_allocs_.erase(opaque);
}
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 437f298616..4a8a270afa 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@ class StreamExecutor {
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
- // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
+ ABSL_DEPRECATED("Use platform() instead.")
PlatformKind platform_kind() const { return platform_kind_; }
// Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@ class StreamExecutor {
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
- //
- // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
uint64 size) SE_MUST_USE_RESULT;
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
- //
- // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
uint64 size) SE_MUST_USE_RESULT;
@@ -699,6 +700,13 @@ class StreamExecutor {
// The set of TraceListeners registered for this StreamExecutor.
std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+ // Allocated memory in bytes.
+ int64 mem_alloc_bytes_;
+
+ // Memory limit in bytes. Value less or equal to 0 indicates there is no
+ // limit.
+ int64 memory_limit_bytes_;
+
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index adac895a17..7ddaf7806e 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -448,7 +448,7 @@ def tf_gen_op_wrapper_cc(
tf_cc_binary(
name = tool,
copts = tf_copts(),
- linkopts = if_not_windows(["-lm"]),
+ linkopts = if_not_windows(["-lm", "-Wl,-ldl"]),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = [op_gen] + deps,
)
@@ -602,6 +602,7 @@ def tf_gen_op_wrappers_cc(
# is invalid to specify both "hidden" and "op_whitelist".
# cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the
# specified ops.
+
def tf_gen_op_wrapper_py(
name,
out = None,
@@ -623,7 +624,7 @@ def tf_gen_op_wrapper_py(
deps = [str(Label("//tensorflow/core:" + name + "_op_lib"))]
tf_cc_binary(
name = tool_name,
- linkopts = if_not_windows(["-lm"]) + cc_linkopts,
+ linkopts = if_not_windows(["-lm", "-Wl,-ldl"]) + cc_linkopts,
copts = tf_copts(),
linkstatic = 1, # Faster to link this one-time-use binary dynamically
deps = ([
@@ -1215,9 +1216,11 @@ def tf_mkl_kernel_library(
if prefix:
srcs = srcs + native.glob(
[prefix + "*.cc"],
+ exclude = [prefix + "*test*"],
)
hdrs = hdrs + native.glob(
[prefix + "*.h"],
+ exclude = [prefix + "*test*"],
)
# -fno-exceptions in nocopts breaks compilation if header modules are enabled.
@@ -1674,7 +1677,7 @@ def py_test(deps = [], data = [], kernels = [], **kwargs):
deps = select({
"//conditions:default": deps,
clean_dep("//tensorflow:no_tensorflow_py_deps"): [],
- }) + tf_binary_dynamic_kernel_deps(kernels),
+ }),
data = data + select({
"//conditions:default": [],
clean_dep("//tensorflow:no_tensorflow_py_deps"): ["//tensorflow/tools/pip_package:win_pip_package_marker"],
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 87745420ee..c3ba2dba57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..3541671bee 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..b113c18ee0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..7210bf5db4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78df4..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea7cb..775130468f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
index 24a58fb118..f06e798953 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "input_layer"
- argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+ argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
}
member_method {
name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index d843194ef0..0869de0243 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index b8e9baca71..20f39fae1e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
index 73b577da37..a296e13158 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
@@ -105,6 +105,10 @@ tf_module {
argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "sparse_categorical_accuracy"
+ argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "sparse_categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 472b9818df..4011719317 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 937516eff1..8a12ac1ad8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index dd9f7c49e0..503e145a91 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1093,6 +1093,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
@@ -1373,6 +1377,10 @@ tf_module {
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "load_library"
+ argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "load_op_library"
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
@@ -1426,7 +1434,7 @@ tf_module {
}
member_method {
name: "map_fn"
- argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], "
+ argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
}
member_method {
name: "matching_files"
@@ -1589,6 +1597,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1797,6 +1809,10 @@ tf_module {
argspec: "args=[\'ref\', \'indices\', \'updates\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 018be7b9f9..c52581dec1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -1,12 +1,16 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 87745420ee..c3ba2dba57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -111,6 +111,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index 6dd46365b0..3541671bee 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 35b7105eba..b113c18ee0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 8ae370af98..7210bf5db4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -112,6 +112,10 @@ tf_class {
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 7027e78df4..ef3409b1b5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesClassifier"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index d8167ea7cb..775130468f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.estimator.BoostedTreesRegressor"
tf_class {
is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
+ is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees._BoostedTreesBase\'>"
is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
is_instance: "<type \'object\'>"
member {
@@ -32,6 +33,14 @@ tf_class {
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_feature_importances"
+ argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "experimental_predict_with_explanations"
+ argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
index 24a58fb118..f06e798953 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "input_layer"
- argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
+ argspec: "args=[\'features\', \'feature_columns\', \'weight_collections\', \'trainable\', \'cols_to_vars\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'None\'], "
}
member_method {
name: "linear_model"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index d843194ef0..0869de0243 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index b8e9baca71..20f39fae1e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
index 73b577da37..a296e13158 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
@@ -105,6 +105,10 @@ tf_module {
argspec: "args=[\'metric\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "sparse_categorical_accuracy"
+ argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "sparse_categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 472b9818df..4011719317 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -151,7 +151,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -159,7 +159,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -219,7 +219,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_generator"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 937516eff1..8a12ac1ad8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -156,7 +156,7 @@ tf_class {
}
member_method {
name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "evaluate_generator"
@@ -164,7 +164,7 @@ tf_class {
}
member_method {
name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "fit_generator"
@@ -228,7 +228,7 @@ tf_class {
}
member_method {
name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \'10\', \'1\', \'False\'], "
}
member_method {
name: "predict_classes"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 9332e16bf6..96212f5528 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -581,10 +581,6 @@ tf_module {
argspec: "args=[\'op_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "Print"
- argspec: "args=[\'input_\', \'data\', \'message\', \'first_n\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "abs"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1041,6 +1037,10 @@ tf_module {
argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "extract_volume_patches"
+ argspec: "args=[\'input\', \'ksizes\', \'strides\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
@@ -1321,6 +1321,10 @@ tf_module {
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "load_library"
+ argspec: "args=[\'library_location\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "load_op_library"
argspec: "args=[\'library_filename\'], varargs=None, keywords=None, defaults=None"
}
@@ -1374,7 +1378,7 @@ tf_module {
}
member_method {
name: "map_fn"
- argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], "
+ argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
}
member_method {
name: "matching_files"
@@ -1537,6 +1541,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "print"
+ argspec: "args=[], varargs=inputs, keywords=kwargs, defaults=None"
+ }
+ member_method {
name: "py_func"
argspec: "args=[\'func\', \'inp\', \'Tout\', \'stateful\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -1721,6 +1729,10 @@ tf_module {
argspec: "args=[\'indices\', \'updates\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "searchsorted"
+ argspec: "args=[\'sorted_sequence\', \'values\', \'side\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'left\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 018be7b9f9..c52581dec1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -1,12 +1,16 @@
path: "tensorflow.strings"
tf_module {
member_method {
+ name: "format"
+ argspec: "args=[\'template\', \'inputs\', \'placeholder\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'{}\', \'3\', \'None\'], "
+ }
+ member_method {
name: "join"
argspec: "args=[\'inputs\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "length"
- argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "regex_full_match"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index b21dabbde7..cb6da5088b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -265,10 +265,6 @@ tf_module {
argspec: "args=[\'graph\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "do_quantize_training_on_graphdef"
- argspec: "args=[\'input_graph\', \'num_bits\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "exponential_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 99bed5714f..d06c7f2d49 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -174,7 +174,7 @@ class ApiCompatibilityTest(test.TestCase):
verbose_diff_message = diff_message
else:
# Do not truncate diff
- self.maxDiffs = None # pylint: disable=invalid-name
+ self.maxDiff = None # pylint: disable=invalid-name
# Now we can run an actual proto diff.
try:
self.assertProtoEquals(expected_dict[key], actual_dict[key])
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
index e026edb6bb..0a55b84ac4 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu.ppc64le
@@ -1,4 +1,4 @@
-FROM nvidia/cuda-ppc64le:9.0-cudnn7-devel-ubuntu16.04
+FROM nvidia/cuda-ppc64le:9.2-cudnn7-devel-ubuntu16.04
LABEL maintainer="William Irons <wdirons@us.ibm.com>"
@@ -26,6 +26,8 @@ ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
# Configure the build for our CUDA configuration.
ENV TF_NEED_CUDA 1
ENV TF_CUDA_COMPUTE_CAPABILITIES 3.0
+ENV TF_CUDA_VERSION 9.2
+ENV CUDA_TOOLKIT_PATH /usr/local/cuda-9.2
# TODO get NCCL 2 in the docker image
ENV TF_NCCL_VERSION 1
diff --git a/tensorflow/tools/ci_build/README.md b/tensorflow/tools/ci_build/README.md
index f2161b700a..e2fd977f50 100644
--- a/tensorflow/tools/ci_build/README.md
+++ b/tensorflow/tools/ci_build/README.md
@@ -24,7 +24,7 @@ natively on your system.
### Run TensorFlow CI Scripts Natively on your Machine
-1. Follow the instructions at https://www.tensorflow.org/install/install_sources,
+1. Follow the instructions at https://www.tensorflow.org/install/source,
but stop when you get to the section "Configure the installation". You do not
need to configure the installation to run the CI scripts.
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index bbaf59c69a..17198a6560 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -64,7 +64,7 @@ while true; do
fi
done
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# PIP tests should have a "different" path. Different than the one we place
# virtualenv, because we are deleting and recreating it here.
@@ -76,7 +76,7 @@ ln -s $(pwd)/tensorflow ${PIP_TEST_ROOT}/tensorflow
# Do not run tests with "no_pip" tag. If running GPU tests, also do not run
# tests with no_pip_gpu tag.
-PIP_TEST_FILTER_TAG="-no_pip,-no_oss"
+PIP_TEST_FILTER_TAG="-no_pip,-no_oss,-benchmark-test"
if [[ ${IS_OSS_SERIAL} == "1" ]]; then
PIP_TEST_FILTER_TAG="$(echo "${PIP_TEST_FILTER_TAG}" | sed s/-no_oss//)"
PIP_TEST_FILTER_TAG="${PIP_TEST_FILTER_TAG},oss_serial"
@@ -85,7 +85,7 @@ else
fi
if [[ ${IS_GPU} == "1" ]]; then
- PIP_TEST_FILTER_TAG="-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
+ PIP_TEST_FILTER_TAG="-no_gpu,-no_pip_gpu,${PIP_TEST_FILTER_TAG}"
fi
if [[ ${IS_MAC} == "1" ]]; then
PIP_TEST_FILTER_TAG="-nomac,${PIP_TEST_FILTER_TAG}"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index cc09784c1d..49a9048c03 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -147,7 +147,7 @@ PIP_INTEGRATION_TESTS_FLAG="--integration_tests"
ANDROID_CMD="${CI_BUILD_DIR}/builds/android.sh"
ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh"
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute'
BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh"
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index 75da9bb835..cd7206baf8 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -16,29 +16,25 @@
#
#
# A script to run multiple GPU tests in parallel controlled with an environment
-# variable. This script will assume that when it runs, one of the locks are
-# already released. So the program calling this script is expected to make sure
-# that only $TF_GPU_COUNT processes are running at any gien time.
+# variable.
#
# Required environment variables:
-# TF_GPU_COUNT = Number of GPUs available. This HAS TO BE IN SYNC with the
-# value of --local_test_jobs flag for bazel.
+# TF_GPU_COUNT = Number of GPUs available.
-BASH_VER_MAJOR=$(echo ${BASH_VERSION} | cut -d '.' -f 1)
-BASH_VER_MINOR=$(echo ${BASH_VERSION} | cut -d '.' -f 2)
-
-if [[ ${BASH_VER_MAJOR} -lt 4 ]]; then
- echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
- exit 1
-elif [[ ${BASH_VER_MAJOR} -eq 4 ]] && [[ ${BASH_VER_MINOR} -lt 2 ]]; then
- echo "Insufficient bash version: ${BASH_VERSION} < 4.2" >&2
- exit 1
-fi
-
-function is_absolute {
- [[ "$1" = /* ]] || [[ "$1" =~ ^[a-zA-Z]:[/\\].* ]]
-}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
+TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-8}
+# We want to allow running one of the following configs:
+# - 4 tests per GPU on k80
+# - 8 tests per GPU on p100
+# p100 has minimum 12G memory. Therefore, we should limit each test to 1.5G.
+# To leave some room in case we want to run more tests in parallel in the
+# future and to use a rounder number, we set it to 1G.
+export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024
+# *******************************************************************
+# This section of the script is needed to
+# make things work on windows under msys.
+# *******************************************************************
RUNFILES_MANIFEST_FILE="${TEST_SRCDIR}/MANIFEST"
function rlocation() {
if is_absolute "$1" ; then
@@ -55,29 +51,32 @@ function rlocation() {
TEST_BINARY="$(rlocation $TEST_WORKSPACE/${1#./})"
shift
+# *******************************************************************
-# Make sure /var/lock exists, this may not be true under MSYS
mkdir -p /var/lock
-
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
-
-for i in `seq 0 $((TF_GPU_COUNT-1))`; do
- exec {lock_fd}>/var/lock/gpulock$i || exit 1
- if flock -n "$lock_fd";
- then
- (
- # This export only works within the brackets, so it is isolated to one
- # single command.
- export CUDA_VISIBLE_DEVICES=$i
- echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
- "$TEST_BINARY" $@
- )
- return_code=$?
- flock -u "$lock_fd"
- exit $return_code
- fi
+# Try to acquire any of the TF_GPU_COUNT * TF_TESTS_PER_GPU
+# slots to run a test at.
+#
+# Prefer to allocate 1 test per GPU over 4 tests on 1 GPU.
+# So, we iterate over TF_TESTS_PER_GPU first.
+for j in `seq 0 $((TF_TESTS_PER_GPU-1))`; do
+ for i in `seq 0 $((TF_GPU_COUNT-1))`; do
+ exec {lock_fd}>/var/lock/gpulock${i}_${j} || exit 1
+ if flock -n "$lock_fd";
+ then
+ (
+ # This export only works within the brackets, so it is isolated to one
+ # single command.
+ export CUDA_VISIBLE_DEVICES=$i
+ echo "Running test $TEST_BINARY $* on GPU $CUDA_VISIBLE_DEVICES"
+ "$TEST_BINARY" $@
+ )
+ return_code=$?
+ flock -u "$lock_fd"
+ exit $return_code
+ fi
+ done
done
echo "Cannot find a free GPU to run the test $* on, exiting with failure..."
exit 1
-
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index a9ae715c6a..4ced96f90b 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -68,8 +68,8 @@ else
pip3 install --upgrade numpy==1.14.5
fi
-pip2 install scipy==0.18.1
-pip3 install scipy==0.18.1
+pip2 install scipy==1.1.0
+pip3 install scipy==1.1.0
pip2 install scikit-learn==0.18.1
pip3 install scikit-learn==0.18.1
diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
index 2a9f295188..7be5f454ec 100755
--- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
+++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh
@@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py
# Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution
# in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads
# caused by executing multiple tests concurrently.
-bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \
+bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \
--jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \
--config=mkl --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \
//tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/...
diff --git a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
index 28d5565b98..34847e637a 100644
--- a/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/gpu/pip/build_tf_windows.sh
@@ -122,7 +122,7 @@ fi
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow_gpu-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
-TF_GPU_COUNT=${TF_GPU_COUNT:-8}
+TF_GPU_COUNT=${TF_GPU_COUNT:-4}
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result testing system installed tensorflow
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
index 35a74c9664..68ba7a2630 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v0_11.py
@@ -94,7 +94,7 @@ class TestUpgrade(test_util.TensorFlowTestCase):
self.assertAllClose(
tf.reduce_logsumexp(a, [0, 1]).eval(), 6.45619344711)
self.assertAllEqual(
- tf.expand_dims([[1, 2], [3, 4]], dim=1).eval(),
+ tf.expand_dims([[1, 2], [3, 4]], axis=1).eval(),
[[[1, 2]], [[3, 4]]])
def testArgMinMax(self):
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 38216ce9b1..53c546b10c 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -120,10 +120,18 @@ Simple usage:
report_filename = args.report_filename
files_processed = 0
if args.input_file:
+ if not args.output_file:
+ raise ValueError(
+ "--outfile=<output file> argument is required when converting a "
+ "single file.")
files_processed, report_text, errors = upgrade.process_file(
args.input_file, args.output_file)
files_processed = 1
elif args.input_tree:
+ if not args.output_tree:
+ raise ValueError(
+ "--outtree=<output directory> argument is required when converting a "
+ "file tree.")
files_processed, report_text, errors = upgrade.process_tree(
args.input_tree, args.output_tree, args.copy_other_files)
else:
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index 228d5ee35d..f8ed74aaf7 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -23,7 +23,7 @@ You can test specify version of TensorFlow:
./local_test.sh ${whl_file_url}
```
-For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/install_linux#the_url_of_the_tensorflow_python_package) for Ubuntu.
+For example, you can find these TensorFlow python package URLs from [here](https://www.tensorflow.org/install/pip) for Ubuntu.
**2) Launch a remote k8s cluster on Google Kubernetes Engine (GKE) and run the
test suite on it**
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 39e7bc8b66..c741e8ad0c 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -78,7 +78,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
# TODO(craigcitro): Don't install the pip package, since it makes it
# more difficult to experiment with local changes. Instead, just add
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index e487779e7a..f544725af4 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -100,7 +100,7 @@ RUN mkdir /bazel && \
# Download and build TensorFlow.
WORKDIR /tensorflow
-RUN git clone --branch=r1.10 --depth=1 https://github.com/tensorflow/tensorflow.git .
+RUN git clone --branch=r1.11 --depth=1 https://github.com/tensorflow/tensorflow.git .
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 371451d2aa..db7c701289 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -3,7 +3,7 @@ FROM ubuntu:16.04
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
# These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.10
+ARG TF_BUILD_VERSION=r1.11
ARG PYTHON="python"
ARG PYTHON3_DEV=""
ARG WHL_DIR="/tmp/pip"
diff --git a/tensorflow/tools/docker/jupyter_notebook_config.py b/tensorflow/tools/docker/jupyter_notebook_config.py
index 05dcefb099..4449e3501f 100644
--- a/tensorflow/tools/docker/jupyter_notebook_config.py
+++ b/tensorflow/tools/docker/jupyter_notebook_config.py
@@ -16,7 +16,7 @@ import os
from IPython.lib import passwd
c = c # pylint:disable=undefined-variable
-c.NotebookApp.ip = '*'
+c.NotebookApp.ip = '0.0.0.0' # https://github.com/jupyter/notebook/issues/3946
c.NotebookApp.port = int(os.getenv('PORT', 8888))
c.NotebookApp.open_browser = False
diff --git a/tensorflow/tools/docker/parameterized_docker_build.sh b/tensorflow/tools/docker/parameterized_docker_build.sh
index 448a3a7647..570aa8278c 100755
--- a/tensorflow/tools/docker/parameterized_docker_build.sh
+++ b/tensorflow/tools/docker/parameterized_docker_build.sh
@@ -244,7 +244,7 @@ if [[ "${TF_DOCKER_BUILD_IS_DEVEL}" == "no" ]]; then
if [[ "${TF_DOCKER_BUILD_TYPE}" == "gpu" ]]; then
export TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS=\
- "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2"
+ "${TF_BUILD_APPEND_CI_DOCKER_EXTRA_PARAMS} -e TF_CUDA_COMPUTE_CAPABILITIES=3.0,3.5,5.2,6.0"
fi
pushd "${SCRIPT_DIR}/../../../"
diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md
index d64db35afb..5996573cf1 100644
--- a/tensorflow/tools/dockerfiles/README.md
+++ b/tensorflow/tools/dockerfiles/README.md
@@ -34,13 +34,13 @@ documentation](https://docs.docker.com/engine/reference/run/).
# User permissions (-u) are required if you use (-v).
# CPU-based images
-$ docker run -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
# GPU-based images (set up nvidia-docker2 first)
-$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(PWD):/my-devel -it tf
+$ docker run --runtime=nvidia -u $(id -u):$(id -g) -v $(pwd):/my-devel -it tf
# Images with Jupyter run on port 8888, and needs a volume for notebooks
-$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(PWD):/notebooks -it tf
+$ docker run --user $(id -u):$(id -g) -p 8888:8888 -v $(pwd):/notebooks -it tf
```
These images do not come with the TensorFlow source code -- but the development
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 4f7efe193f..b218e900bf 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -91,9 +91,10 @@ py_binary(
":parser",
":pretty_docs",
":py_guide_parser",
- "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
+ "//tensorflow/python:util",
"//tensorflow/tools/common:public_api",
"//tensorflow/tools/common:traverse",
+ "@six_archive//:six",
],
)
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index 1cd9cb7ca9..77a3ca2052 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -453,7 +453,11 @@ def update_id_tags_inplace(src_dir):
EXCLUDED = set(['__init__.py', 'OWNERS', 'README.txt'])
-def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
+def replace_refs(src_dir,
+ output_dir,
+ reference_resolver,
+ file_pattern='*.md',
+ api_docs_relpath='api_docs'):
"""Fix @{} references in all files under `src_dir` matching `file_pattern`.
A matching directory structure, with the modified files is
@@ -472,12 +476,13 @@ def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
reference_resolver: A `parser.ReferenceResolver` to make the replacements.
file_pattern: Only replace references in files matching file_patters,
using fnmatch. Non-matching files are copied unchanged.
+ api_docs_relpath: Relative-path string to the api_docs, from the src_dir.
"""
# Iterate through all the source files and process them.
for dirpath, _, filenames in os.walk(src_dir):
+ depth = os.path.relpath(src_dir, start=dirpath)
# How to get from `dirpath` to api_docs/python/
- relative_path_to_root = os.path.relpath(
- path=os.path.join(src_dir, 'api_docs/python'), start=dirpath)
+ relative_path_to_root = os.path.join(depth, api_docs_relpath, 'python')
# Make the directory under output_dir.
new_dir = os.path.join(output_dir,
@@ -497,7 +502,8 @@ def replace_refs(src_dir, output_dir, reference_resolver, file_pattern='*.md'):
full_out_path = os.path.join(output_dir, suffix)
# Copy files that do not match the file_pattern, unmodified.
if not fnmatch.fnmatch(base_name, file_pattern):
- shutil.copyfile(full_in_path, full_out_path)
+ if full_in_path != full_out_path:
+ shutil.copyfile(full_in_path, full_out_path)
continue
with open(full_in_path, 'rb') as f:
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index a6159fa692..83b4bf8128 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -1479,7 +1479,7 @@ class ParserConfig(object):
self.base_dir = base_dir
self.defined_in_prefix = 'tensorflow/'
self.code_url_prefix = (
- 'https://www.tensorflow.org/code/tensorflow/') # pylint: disable=line-too-long
+ '/code/stable/tensorflow/') # pylint: disable=line-too-long
def py_name_to_object(self, full_name):
"""Return the Python object for a Python symbol name."""
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 91c5cd094c..12354a6ab2 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -60,18 +60,9 @@ COMMON_PIP_DEPS = [
":included_headers",
"//tensorflow:tensorflow_py",
"//tensorflow/contrib/autograph:autograph",
- "//tensorflow/contrib/autograph/converters:converters",
- "//tensorflow/contrib/autograph/core:core",
- "//tensorflow/contrib/autograph/core:test_lib",
- "//tensorflow/contrib/autograph/impl:impl",
- "//tensorflow/contrib/autograph/lang:lang",
- "//tensorflow/contrib/autograph/operators:operators",
- "//tensorflow/contrib/autograph/pyct:pyct",
- "//tensorflow/contrib/autograph/pyct/testing:testing",
- "//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
- "//tensorflow/contrib/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
+ "//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
@@ -102,6 +93,16 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/timeseries:timeseries_pip",
"//tensorflow/contrib/tpu",
"//tensorflow/examples/tutorials/mnist:package",
+ # "//tensorflow/python/autograph/converters:converters",
+ # "//tensorflow/python/autograph/core:core",
+ "//tensorflow/python/autograph/core:test_lib",
+ # "//tensorflow/python/autograph/impl:impl",
+ # "//tensorflow/python/autograph/lang:lang",
+ # "//tensorflow/python/autograph/operators:operators",
+ # "//tensorflow/python/autograph/pyct:pyct",
+ # "//tensorflow/python/autograph/pyct/testing:testing",
+ # "//tensorflow/python/autograph/pyct/static_analysis:static_analysis",
+ "//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python:cond_v2",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:meta_graph_testdata",
@@ -114,6 +115,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/tools:tools_pip",
"//tensorflow/python/tools/api/generator:create_python_api",
"//tensorflow/python:test_ops",
+ "//tensorflow/python:while_v2",
"//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
]
@@ -210,6 +212,7 @@ filegroup(
"@ngraph//:LICENSE",
"@ngraph_tf//:LICENSE",
"@nlohmann_json_lib//:LICENSE.MIT",
+ "@tbb//:LICENSE",
]) + tf_additional_license_deps(),
)
diff --git a/tensorflow/tools/pip_package/build_pip_package.sh b/tensorflow/tools/pip_package/build_pip_package.sh
index 666ea75d46..c62271c5cb 100755
--- a/tensorflow/tools/pip_package/build_pip_package.sh
+++ b/tensorflow/tools/pip_package/build_pip_package.sh
@@ -43,8 +43,7 @@ function cp_external() {
PLATFORM="$(uname -s | tr 'A-Z' 'a-z')"
function is_windows() {
- # On windows, the shell script is actually running in msys
- if [[ "${PLATFORM}" =~ (mingw64|msys)_nt* ]]; then
+ if [[ "${PLATFORM}" =~ (cygwin|mingw32|mingw64|msys)_nt* ]]; then
true
else
false
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 3102239a19..b95e1f5c87 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -45,7 +45,7 @@ DOCLINES = __doc__.split('\n')
# This version string is semver compatible, but incompatible with pip.
# For pip, we will remove all '-' characters from this string, and use the
# result for pip.
-_VERSION = '1.10.0'
+_VERSION = '1.11.0-rc1'
REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
@@ -57,7 +57,7 @@ REQUIRED_PACKAGES = [
'six >= 1.10.0',
'protobuf >= 3.6.0',
'setuptools <= 39.1.0',
- 'tensorboard >= 1.10.0, < 1.11.0',
+ 'tensorboard >= 1.11.0, < 1.12.0',
'termcolor >= 1.1.0',
]
@@ -86,7 +86,7 @@ else:
if 'tf_nightly' in project_name:
for i, pkg in enumerate(REQUIRED_PACKAGES):
if 'tensorboard' in pkg:
- REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.11.0a0, < 1.12.0a0'
+ REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.12.0a0, < 1.13.0a0'
break
# weakref.finalize and enum were introduced in Python 3.4
diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py
index 9181c9bd4a..a883ce221f 100644
--- a/tensorflow/tools/test/check_futures_test.py
+++ b/tensorflow/tools/test/check_futures_test.py
@@ -37,6 +37,7 @@ BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
FUTURES_PATTERN = re.compile(r'^from __future__ import (\w+)\s*$')
FUTURES_PATTERN_2 = re.compile(
r'^from __future__ import (\w+), (\w+), (\w+)\s*$')
+FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$')
REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function'])
WHITELIST = [
@@ -59,6 +60,8 @@ def check_file(path, old_division):
for line in open(path, encoding='utf-8') if six.PY3 else open(path):
count += 1
m = FUTURES_PATTERN.match(line)
+ if not m:
+ m = FUTURES_PATTERN_3.match(line)
if m:
futures.add(m.group(1))
else:
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8e6f4143a9..915fee6a1f 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -106,11 +106,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/fb462224c058487763f263b7995d70efd0242c17.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/e01d95528ea2137a4a27a88d1f57c6cb260aafed.tar.gz",
],
- sha256 = "f4f34f90083d5259f9a1a4067749d842599748d8ca03c1d9fe723124a7045c63",
- strip_prefix = "abseil-cpp-fb462224c058487763f263b7995d70efd0242c17",
+ sha256 = "84043ed402d2a2a6ba4cdddb7e85118b1158fd81fe4ac3a14adc343d054c1e2e",
+ strip_prefix = "abseil-cpp-e01d95528ea2137a4a27a88d1f57c6cb260aafed",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
)
@@ -179,6 +179,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
+ system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
+ },
)
tf_http_archive(
@@ -190,6 +194,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sha256 = "824870d87a176f26bcef663e92051f532fac756d1a06b404055dc078425f4378",
strip_prefix = "googleapis-f81082ea1e2f85c43649bee26e0d9871d4b41cdb",
build_file = clean_dep("//third_party:googleapis.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:googleapis.BUILD"),
)
tf_http_archive(
@@ -319,6 +324,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
sha256 = "7068908321ecd2774f145193c4b34a11305bd104b4551b09273dfd1d6a374930",
strip_prefix = "gast-0.2.0",
build_file = clean_dep("//third_party:gast.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:gast.BUILD"),
)
tf_http_archive(
@@ -341,6 +347,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "95160f778a62c7a60ddeadc7bf2d83f85a23a27359814aca12cf949e896fa82c",
strip_prefix = "abseil-py-pypi-v0.2.2",
+ system_build_file = clean_dep("//third_party/systemlibs:absl_py.BUILD"),
+ system_link_files = {
+ "//third_party/systemlibs:absl_py.absl.flags.BUILD": "absl/flags/BUILD",
+ "//third_party/systemlibs:absl_py.absl.testing.BUILD": "absl/testing/BUILD",
+ },
)
tf_http_archive(
@@ -491,11 +502,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/738b5f5028ef39cbb023967f80fa2e5dd568556b.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/7167e4d196a50f78abe8af6553c943d50b757a13.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/7167e4d196a50f78abe8af6553c943d50b757a13.tar.gz",
],
- sha256 = "2bda8dd724ab432c162fb6eace259ccf8a97f13cb627336611bff68da2f33ec2",
- strip_prefix = "llvm-738b5f5028ef39cbb023967f80fa2e5dd568556b",
+ sha256 = "11d933232b27531abc83592fc9f03e7f928e504c7d478eeaba51efa929a3d9df",
+ strip_prefix = "llvm-7167e4d196a50f78abe8af6553c943d50b757a13",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
)
@@ -531,6 +542,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
sha256 = "1188e29000013ed6517168600fc35a010d58c5d321846d6a6dfee74e4c788b45",
strip_prefix = "boringssl-7f634429a04abc48e2eb041c81c5235816c96514",
+ system_build_file = clean_dep("//third_party/systemlibs:boringssl.BUILD"),
)
tf_http_archive(
@@ -738,14 +750,16 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
)
- native.new_http_archive(
+ tf_http_archive(
name = "double_conversion",
urls = [
+ "https://mirror.bazel.build/github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
"https://github.com/google/double-conversion/archive/3992066a95b823efc8ccc1baf82a1cfc73f6e9b8.zip",
],
sha256 = "2f7fbffac0d98d201ad0586f686034371a6d152ca67508ab611adc2386ad30de",
strip_prefix = "double-conversion-3992066a95b823efc8ccc1baf82a1cfc73f6e9b8",
build_file = clean_dep("//third_party:double_conversion.BUILD"),
+ system_build_file = clean_dep("//third_party/systemlibs:double_conversion.BUILD"),
)
tf_http_archive(
@@ -831,13 +845,24 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
+ name = "tbb",
+ urls = [
+ "https://mirror.bazel.build/github.com/01org/tbb/archive/tbb_2018.zip",
+ "https://github.com/01org/tbb/archive/tbb_2018.zip",
+ ],
+ sha256 = "724686f90bcda78f13b76f297d964008737ccd6399328143c1c0093e73ae6a13",
+ strip_prefix = "tbb-tbb_2018",
+ build_file = clean_dep("//third_party/ngraph:tbb.BUILD"),
+ )
+
+ tf_http_archive(
name = "ngraph",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
- "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.8.0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.8.0.tar.gz",
],
- sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
- strip_prefix = "ngraph-0.5.0",
+ sha256 = "a8cf3ef2d0e6d31b54eb33f6a9e795f562195ce5c2a857e729ca9c35241cc45c",
+ strip_prefix = "ngraph-0.8.0",
build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
)
@@ -855,11 +880,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "ngraph_tf",
urls = [
- "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
- "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc1.tar.gz",
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.6.0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.6.0.tar.gz",
],
- sha256 = "7919332cb15120101c3e05c1b969a5e029a6411581312583c8f80b6aaaa83072",
- strip_prefix = "ngraph-tf-0.3.0-rc1",
+ sha256 = "1f49391c02bef24872e9f85591e60e0e7eef12a337db71390444118049fe451f",
+ strip_prefix = "ngraph-tf-0.6.0",
build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
)